From b5268dc3a0847dce2668265e07ff50d54265b2d8 Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Thu, 15 Sep 2022 11:08:16 +0000 Subject: [PATCH] add centripetal text model --- configs/det/det_r18_vd_ct.yml | 107 ++++++ doc/doc_ch/algorithm_det_ct.md | 95 +++++ doc/doc_en/algorithm_det_ct_en.md | 96 +++++ doc/imgs_results/det_res_img623_ct.jpg | Bin 0 -> 140971 bytes ppocr/data/imaug/__init__.py | 1 + ppocr/data/imaug/ct_process.py | 355 ++++++++++++++++++ ppocr/data/imaug/label_ops.py | 26 ++ ppocr/losses/__init__.py | 3 +- ppocr/losses/det_ct_loss.py | 276 ++++++++++++++ ppocr/metrics/__init__.py | 4 +- ppocr/metrics/ct_metric.py | 52 +++ ppocr/modeling/heads/__init__.py | 3 +- ppocr/modeling/heads/det_ct_head.py | 69 ++++ ppocr/modeling/necks/__init__.py | 4 +- ppocr/modeling/necks/ct_fpn.py | 185 +++++++++ ppocr/postprocess/__init__.py | 3 +- ppocr/postprocess/ct_postprocess.py | 154 ++++++++ ppocr/utils/e2e_metric/Deteval.py | 225 ++++++++--- requirements.txt | 1 + .../configs/det_r18_ct/train_infer_python.txt | 53 +++ test_tipc/prepare.sh | 5 + tools/infer/predict_det.py | 8 +- tools/program.py | 2 +- tools/train.py | 15 +- train.sh | 2 +- 25 files changed, 1682 insertions(+), 62 deletions(-) create mode 100644 configs/det/det_r18_vd_ct.yml create mode 100644 doc/doc_ch/algorithm_det_ct.md create mode 100644 doc/doc_en/algorithm_det_ct_en.md create mode 100644 doc/imgs_results/det_res_img623_ct.jpg create mode 100644 ppocr/data/imaug/ct_process.py create mode 100755 ppocr/losses/det_ct_loss.py create mode 100644 ppocr/metrics/ct_metric.py create mode 100644 ppocr/modeling/heads/det_ct_head.py create mode 100644 ppocr/modeling/necks/ct_fpn.py create mode 100755 ppocr/postprocess/ct_postprocess.py create mode 100644 test_tipc/configs/det_r18_ct/train_infer_python.txt diff --git a/configs/det/det_r18_vd_ct.yml b/configs/det/det_r18_vd_ct.yml new file mode 100644 index 00000000..42922dfd --- /dev/null +++ b/configs/det/det_r18_vd_ct.yml @@ -0,0 +1,107 @@ +Global: + use_gpu: true + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/det_ct/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations + eval_batch_step: [0,1000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img623.jpg + save_res_path: ./output/det_ct/predicts_ct.txt + +Architecture: + model_type: det + algorithm: CT + Transform: + Backbone: + name: ResNet_vd + layers: 18 + Neck: + name: CTFPN + Head: + name: CT_Head + in_channels: 512 + hidden_dim: 128 + num_classes: 3 + +Loss: + name: CTLoss + +Optimizer: + name: Adam + lr: #PolynomialDecay + name: Linear + learning_rate: 0.001 + end_lr: 0. + epochs: 600 + step_each_epoch: 1254 + power: 0.9 + +PostProcess: + name: CTPostProcess + box_type: poly + +Metric: + name: CTMetric + main_indicator: f_score + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/total_text/train + label_file_list: + - ./train_data/total_text/train/train.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: RGB + channel_first: False + - CTLabelEncode: # Class handling label + - RandomScale: + - MakeShrink: + - GroupRandomHorizontalFlip: + - GroupRandomRotate: + - GroupRandomCropPadding: + - MakeCentripetalShift: + - ColorJitter: + brightness: 0.125 + saturation: 0.5 + - ToCHWImage: + - NormalizeImage: + - KeepKeys: + keep_keys: ['image', 'gt_kernel', 'training_mask', 'gt_instance', 'gt_kernel_instance', 'training_mask_distance', 'gt_distance'] # the order of the dataloader list + loader: + shuffle: True + drop_last: True + batch_size_per_card: 4 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/total_text/test + label_file_list: + - ./train_data/total_text/test/test.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: RGB + channel_first: False + - CTLabelEncode: # Class handling label + - ScaleAlignedShort: + - NormalizeImage: + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'texts'] # the order of the dataloader list + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 2 diff --git a/doc/doc_ch/algorithm_det_ct.md b/doc/doc_ch/algorithm_det_ct.md new file mode 100644 index 00000000..ea3522b7 --- /dev/null +++ b/doc/doc_ch/algorithm_det_ct.md @@ -0,0 +1,95 @@ +# CT + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945) +> Tao Sheng, Jie Chen, Zhouhui Lian +> NeurIPS, 2021 + + +在Total-Text文本检测公开数据集上,算法复现效果如下: + +|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接| +| --- | --- | --- | --- | --- | --- | --- | +|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)| + + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +CT模型使用Total-Text文本检测公开数据集训练得到,数据集下载可参考 [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset), 我们将标签文件转成了paddleocr格式,转换好的标签文件下载参考[train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [text.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt)。 + +请参考[文本检测训练教程](./detection.md)。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。 + + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将CT文本检测训练过程中保存的模型,转换成inference model。以基于Resnet18_vd骨干网络,在Total-Text英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar) ),可以使用如下命令进行转换: + +```shell +python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct +``` + +CT文本检测模型推理,可以执行如下命令: + +```shell +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT" +``` + +可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](../imgs_results/det_res_img623_ct.jpg) + + + +### 4.2 C++推理 + +暂不支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. FAQ + + +## 引用 + +```bibtex +@inproceedings{sheng2021centripetaltext, + title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection}, + author={Tao Sheng and Jie Chen and Zhouhui Lian}, + booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, + year={2021} +} +``` diff --git a/doc/doc_en/algorithm_det_ct_en.md b/doc/doc_en/algorithm_det_ct_en.md new file mode 100644 index 00000000..d56b3fc6 --- /dev/null +++ b/doc/doc_en/algorithm_det_ct_en.md @@ -0,0 +1,96 @@ +# CT + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945) +> Tao Sheng, Jie Chen, Zhouhui Lian +> NeurIPS, 2021 + + +On the Total-Text dataset, the text detection result is as follows: + +|Model|Backbone|Configuration|Precision|Recall|Hmean|Download| +| --- | --- | --- | --- | --- | --- | --- | +|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)| + + + +## 2. Environment +Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md). + + + +## 3. Model Training / Evaluation / Prediction + + +The above CT model is trained using the Total-Text text detection public dataset. For the download of the dataset, please refer to [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset). PaddleOCR format annotation download link [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt). + + +Please refer to [text detection training tutorial](./detection_en.md). PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different detection models. + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, convert the model saved in the CT text detection training process into an inference model. Taking the model based on the Resnet18_vd backbone network and trained on the Total Text English dataset as example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), you can use the following command to convert: + +```shell +python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct +``` + +CT text detection model inference, you can execute the following command: + +```shell +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT" +``` + +The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: + +![](../imgs_results/det_res_img623_ct.jpg) + + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. FAQ + + +## Citation + +```bibtex +@inproceedings{sheng2021centripetaltext, + title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection}, + author={Tao Sheng and Jie Chen and Zhouhui Lian}, + booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, + year={2021} +} +``` diff --git a/doc/imgs_results/det_res_img623_ct.jpg b/doc/imgs_results/det_res_img623_ct.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c5f57d96cca896c70d9e0d33ba80a0177a8aeb9 GIT binary patch literal 140971 zcmeFYcRXC*w>N(D5+Pc2Bhd+>MxRJNdZI;%5+X#69=(hbgfKe6FcO_0M33G@km%7F zz0RnEF@BSLzW4k6-Fu&VU-xa2h*FHFBpMBO^d#(3gd#!b@rmhx&8yYI=DgYiH z0N~+%fU6ncA%Ktf`}lpv|9ub={5}&A5)u#+6A=^tX(VK%#3W=S#Kff6NXf{5AGp6L zu8~vx{^$1~e?N*(L_k19PC`ubr_2A=>8cH&CdFsRA0fcI4d7Gb5m4h@bpaq8okV}g z!x8)Ufrn2(NJLD6BjXybL(L5w`2++wqKR-6<9Y|;?gNC>L^QWVm56C|ElF;>(TRm5 zd?4j~P}xqe_Y=t_{=z+!jQl18BNH<>4=*3TfP|#fJ!u))hkq%nsH&-JJkd8WG!Q-{3m$#3vU)Y=Qh{(55(TPdPDXD4i(=$Hi=H(X@78RFNeW|Xgt*dWn z?C9+3?&?m@zb%62kNaXlpc3-_PoH@PI1 zxKY1Z`@`A)jIq%FBhLQG*uVIi1}F&daFa(s4S<15w#)EVON(C~h}NBkrG*$UV-VZC zBbIw|?^v_U%D5S&l4tYv{g#o&v)Wla7$D77I>~?~^SmIFBF9iJ9)Xcb4iHR|Du!mY z6+t9Vo{Rk$cm-_ndYhlISc{tPhSk5wiW);oT>+tH*ba^QDnb)$id@4yl@2|AHkx2K0@TNj>7e*zkCygSaiXfO&4APTq9RN zXW12C2|m9CIrG?pW0K4LtU;a9-Y#1-&4?efMjrr}r{tDcKxbGaac{G_-7 z2H^Tvz)HqBKJ=20VEqa}d$#WIfY%~yv1?A+S3sg27^w-F(!{>C{l}Xof6(cW4EJKm za3ic^)&A}BzY!pEehdB&0_MUEXp0uOfyBfD@ya|L|--COZbZ^$3+{x!J8eck&Y@IRRT*JaiJ^rYkw5#9eA z^8d;5x61urP$F}zf`{gRp_zyinj!F#3S@QgZ;Ms~Sn=+&<6Qxib(e3hfQ^WM+yVWz zjK3}J{I_jS{$M-*`Cp9rAM6CD?Swe%MEtSNE8vfTXZ|s8@V`O8!ru^3k;8{zAN!wK z29BwJnKAbj@R#@%ARPGQH#%Gak*3%V$k9Xa!TLD(56FdjWG7eubKqf2{C^2Lr-xg1 zh@^>Nd4fN9^!nf7L)af4AV07F#wgsoa2N&9_(hxl6Q%xFgjoEG5R7f{(WWLECyn(# zBSBT$nl2z6+Bi5lc>JG4+zMTn$(7zd`#}=?AGPEUrqJRnh!KJ;_yf?*7`TOrbPg5p zH~0T~Y20u0{C9LYE{ugrOZ-b5|IQF`5EwV5c7$evU!>Rk!9Tf?A0)$9z;6Y3guDLV zz#j!7=a@D6r3;|`4K1LKQI0x58hEM9d4EQyLW5GlkA>g zMacg|a_2aG5OMzJ`oW7+Jdxju_h(u9j{~#>=Vk!d*O+G)e_?-M$x!Bni^F8tABu97 zBh|4g{z6hcRE!IRb#>7yQ0YmrBLf|h&mZge-+>PFn>4QgE24F5oF5;-sUw!m^(({@ zL~syfRlFE?K{hwp8v3C&REhHgTWMSCV(Qry5J7_!x$L?EjyAo`ah%p6v=;I25?=#3 z=!mwnl0V30oR;3yaA$y4hv)o+24f$<#C$BzekfWd^4l@$J%@ND`~Ps2kpQx;fE0oX zt+A$`_sJFt5t_<$fbRcpQ{wQH(7{ut`_Bz}*Y0C9z>hafoEi=A<;nDF+1ue!$|j7C zgiWVyt0%;)1>cOJaBAz#r_%xM=CUX2v9gP*kbq49svoHPw<3wdvhT|M?8_mjU>zCp z$cE{KK#u@WKI6x*kbdcI;1-iSR-c?i&!abASP`BA@c2rz$^Ia(=L~~ZSA+5aUUR-~ zk)Y%()rS+o1RK)%>m%_nchyy;)6M&MLzJ>3BY?O6i-;jrwzC|-k9~|9vmwnIqLyw| zUoLx`LnO0torHJwBv(c{C!uF)aypnG&SD~3T7i-eM;`zr%4SywQc18vA>v9*n+s(C z31CZy7ouACNQab{?UD~r?1%xG0L!Vtl*#Ebo83Yji#jBSswgY0Mp_hP=Xnzj8TXYz zH^D2#K55vVYr~i9M>-^X>lm5^_>^Yb!YR($NAu?badKhx+!1U8FQ($=`Gq;r6;Kby zW(jXE^!V!l3SvRkcp;a3y)rz(9%0#3ML9}g$NB8@)j<+(fV*~#?)-Ln>o{f<+@j@< z^QQUjfRE99fC3+GTK@YM6K4YO<#JB^{hQ-RLI+S$5zIihqdeNccvwb92%{qJHfl42 zDsqQ?U$-eKtEib?)zFJeplCHcm^I^U0ZWbp5h8gc1BTXiku{ss_?0J!-4WGK5*DD~ zA&r+u+BG+nvv(nHO@SbqD!~^shGfT0Be2)5UwB}r=gGgI{V(ogd>h6jSWK^gTIGVx z#>XvLw2#C}M&9YGye}U5(N(gNkTmGE>8@0gI_F>Xe$PL=5Zz2jIFdfC_*j;_WDZId7;YFm{?%jX zfU-R)ba5k^rJgexO6(ROfvQc3H73L>u_>`_9-n#O{A+3)Gu@cG=%2`nPDrIuuL~0k z)mvFQ;@(c9#@z5o(=Xb`H#cvAQHQKPN-XZ1%vGt}!Ivt13fBXYmt;NM635wMP0eF) zGo3(0r+fcTBeaR;!xM+LVoS`4AM+?&CxFVF|G+=)xz6$2kJ&s~2M_&2i;rYJslpEc zx`pT>q3SsViVImnI19x+KZD%Ak0-sWCVK@a@_P&~srlyMRS_!@uvb%+i4ef_2ogqu zDL^LLnO#REBW5-811faW?oA;0M?s$Y~_JR=i@isek$oAD^vQ3$e@_Pw5 zj&le?-Tn2j66t#Opd3YPGLXBrt{}Ud?0-H=P|U$qaigTYfh1pqQwFOa z8B(Lamc|9ul3tgfAbF%m8-uUP+mK@`{X(HWzfC5GP){l}$BI=iF|ar`rKcU{HyXxj zOArl^bt0q@P=0OBTYl#wBHv}$&Y8)izgbNOQxp@hhK;0Ui*^&Tch*VgW@{hj5P%D7 zZ~!@Hx3U;A9PktN0}%6rUT)eKUh6+T%Ev7S75A^HA4#zNV(d}aAY=~>24{f(1E8ZT zuK;e%uGW*g=8%okkHKJX$ltac^e)m<)Jh{QQg(EadYa*U)L#Xj%uBEyK}5aowiK&ZWtn zca|@ytybrJDthlrn;qM|+O;8o&O5m^_G0KePbMZg?2g`vM6iS8T5On!KNM9_D&b}9 z5%Cs~nfV*_PnvX^^_0|^CMyz;hGaKkK|?u7Mj1>oMA`w*!$KA?7e<6%8V_P8LQg6_ zDjb?!HbY#7X;is9u#*h3etm6mJ5)rQQtf`2oWbjzfk9SgD#a+oNyTKYk0fNqcDX4t zkE^CI>Q%}t%XT@LHU-KWzl*g z%Kqw#g*lt>YVF;xOKVCWg66?A!7(esIj?7W+Exfo-rWR*J#Q(62NcWN*6XQQ(bh7ax~jEDJXm zY(Tycr&nxE>ds70TMTl4KLDjlXji9N?cgU_7qk> zoe!iMXwqF~=lyo~%#fC=f?TrleWQ=J-s{VfH!n46RK{38c>CEqa_0uVnj8k#ieP9n zPp*K=YDF;zI7^F2Kp;}8_|sMzUQEpC4`q8&3YvM+rgTEr9v1TS^~CSS>2ZU?tbvj!)NEfg~E%(~?GcE$s zqN+YAyOe6<$$A+XmHJ-+s}jBjKH5EU{xGT4eB(k+B~rm~fE`-F_9l$(c?^ti<^|kxD0! zSn!BS4Y$gCg0w%Wwg}5G>Y7;yR=hUxn%0sjt*go|?XhGB-?Xz$^gN2tKl+QP+QeP; z>5_T^CE%-{W0q-_-$qig=Wpk%9>+;i&g)C zPgpOjFnJbhkaA>&2|vF+RPss8^o0f$6)(%av#pCzTA!?H^nSH}g=~jCf$x^_Xgukw z2v&r*t14%hMGXY$1jo>W5f#uDIjX5s#D0&EX4mx1ijUvtX_LCrdX|@a&SyyuIP2qm z<5&~O+-ScziC%v{KyXmF^Vl=SyBz;}8-2h*2 zU3E8^#Cf%*sb7Y?aw{=+NuRxeeA&artO#Auw@?r7ql=N5C$s2S_L`+;n@OFK)wD5DvGoNr~(ahTYf8Z7bfeYhwSIVnm7DM>j(Nf*QPCgVr7zF*#^4Oin= zo>O>YDnX~Kd^sR>9 zw-jg(kgN`Cm{I2Z06eCH)wdLr2~?%G&;KX8k$nse0*vf1C+@XCr(Xe^BYVvlKHYFe zz$h5fE`*DRbTwVO0^Gotvd?fKSs?ptTT|&{Fgf4HV9Fz1Wq?=aB>(qLBeziI@2kJ> z{Aq&O=wg(`Qe8V0o&%T`D})}^Mq)$qAFeHIWh5_4BiQhvdpl|d%0%gJ4G7}}59Lm( za3EC};C%UUy%A~K7BSYshGwctv`AV(N1DW4bVp)Bmz?_=8aUj4Sm(bqsgb2#%lw=- zX7KC8Xx9P<0Rm{Ei@Wo)X!fDaCQF3Ri($-dlj|Kxc3m_DCy4hu$q4k}EfPVN$#vCT zwhKdy|7FR`dN6ZjDb3au+(9l6%x&5$|d9jib+tXG@C ze0R%r+C(TnysF!ULit4~fZx*`TGXbylS?pN5OHa8r)r-3Oe$aT1wbTneDn@IWKFhePQ}jBUa`yrW6`lg8<*^(S7H*T=s{tY(XOuYX`=hoMG3 z)VE|!iPuWE+oc_2wiHyf`8%tlSw5D)(+VHKy9~J9nQYwEdPcmh*t5Ww2EDR%Nw4Z# z3Sjrhn))gdU42CZ#gs9j2TDA{|^%Vfu@TC+XjZXn$ z6kUxLj=(7%4Fjf0Ui10kK3+`pA!NY~Phf}Gz|9wF-@YgH_`brIIoUh|wHdgQ5(VrO z4G(0y4Qt8iDT%=3KP$9TF**<1xT;K=@Iyx9<2!& zne{C>$nh2|X2#0?2F*1n1~rPr6xAlGN$u&L-u;DW)h!@RA;^8jY+w>=9yi%Xf;DY0pxWH&3KPb+peDyNFQxIG5y7d@D z7i5jG#Odl2jNWK~vE+WcHr?iQc9i_roG%j8$JRp6VizGI9Qn+1~B9o6dsK~Kl5oVNDk#8_(;U(i3jx#gC9ybZ6h#n9m-VN(&63a5IG zw&ba3>5(gT;cY$8?fi1R5-3Z=N657CGL}C?EbLxG_MQx0LgfvL(ZWnlo)*`dYjXoX z&Zme^JGp+^>WsRPt9DOVhBeslZ*tj(@u2qB zDsVe|8nje^bMW7ywjO&s1_^O8d*$3Z@L+ttzUtJqCp~;B%kKCMVh;&Uv-kq)wr`@k z`N`$Z%R196@h9=S_KP+g`?^0 zwGQ%ph$WvSlxn$y^eX9fRLu0Cwg#lH3B=>|f&)-k0@;!8D7R&S_r~9lj>@Dal)yJ^wws!5P(3x8_DiG+Z}xn!HLG|#H;z|9LtTazq+qoy@Csk82Xunu0v*%#8qegmT3kr;6NWmQuz7WIPW#uulE2RY@guD(pCpSI!D zKDkTqo@b7|;A;V<_KN%U1kfoE=yM_Z7ZW?U^fcOf4H@$ zm!t(_V`*1bO;JB`5WVdE*eZ7g*uv3wyki6I3AIK*9Kgf{Uu|4-yIW|gspooRof|$2 zMaix1&hS>_x2&VpQD2rsdg0`Sv5Gf*u7F1<;*4I$;lf#BJBub9-yMiXX+yTQ3Lfa~ zSWh4E)6wFO(Sok{;;I^aGCO=gPSP|k2p|D6c_4?2Dw4~hqWN<0lepIDfb!Ykj!aLyi7Z$2%x+U5Cn^>a(E2#;4RXkqbY+n=m7DOnDtHd`oN#vf;UJ;KNl3~ zk+n$xMffmwcXvKsLUAymJ4qRd)SOiF@rkK7*LnTZ$*)Gq_TUg~;EirHhiGW|ND zqjY?IZQwMAM^a)W;Cznew$i|fcBZwkL{ko{T6@aX)%!A>(Pqz6Nq7~8HGpSrtdNBV zDYb}@;X*$xF^w48;v7JU$Rj)N2R=>LE;WL)!ht@2ulam5GzAd19f95jSatB}&4YB? z`-aN|V29bWi#l#GpEvCN-R_em)Nvf|N8eRzmpRR6+~Jdl;o+iGh@y6$5Ev|jbUy>= z^#gF>V4HYq7~Z)KKK#C24~d4!1AlTrZk_IRfA2)FN5ZDkumga#!`^}P>!Cz0V;JD) z3LCl0y_ggcT$XsM-S-N}osR}C=Wt+#bHw|j&SDn1=zQS!>hbI z1}dUk=_O@u5+*oRHw*U@re)M1!d6X8RTI`UR@j&HiM<7)zcZXj9xIN`bD7x~(81u9qJ%~zN z6zbyVZVAbGs@*u+St}Ciw8%vzWz=yt$vDT4cmL_Qj+>QeGV=Al(D`j{NkV(z zH+*_>SpcJpUMjktn%IEs4cl24DfIE8jhFxI%KKHYEA5q-yW!oa5Hr3pcGopAdL%&T zs(=kUGa#wAJsVie4akg>mUs8hP)^lldY%V^0e;IZAZlo-d2aIjI^kCTu-@pLGA*D_ zd@=Lva8S3g!ALpAd6uz~K;Emjvd2Lt@7(}RJqJ79y0n6C$*Fk3a^7BCF>evbLfge* zDa&zBuJpE3k4#S&dB}YV-}c^C#1$}6>g6qP_hGFSRx1!|a*db7^?bhDNRpTqzn^hoWuL^jmE@`2Fjwf2aotPXzrRq6|*1$<--C!pBBVmi(z9c zhogms6V)v0O_P#>kFRNj1&;3w5K@0&Ovnl?nw_Q8b{E}}oy?<}8eO*Ws;$e}RZ#m1 z=ez5>cmJ1!Hq@a{l$bKcdGUC@m-RRd2STfJ}9C|&O6}q?qO_xauQ?mUe9h$NMc#s zvw~(w*F41~T#>@QQ7RTi)iG&O{>9U_WlQtg#0w_3RE`1mu{*{(yDPVbImU$^o`j(x zixv2DJ}!&dVMQCK;#$1o^RHgNS@N!BSNEnVVz6&x5`sa4qL4NOf)#wWPtRRf;v|Yk z-BTRv3v!ZTEWht1o$h>{F*q9wupE{y6)6;^>O;{tmuYEuZu33qs270JmaH=;mG*s_ zt`WAWso7;{i;Do$(x;Cu?KL!6UB67bOTs>b1CEpHaEjKKUs=#)HBtLbn!b6tGui4*?4N5$(Dw|fzzaXn3= zy$eWPmq_fL36Rmmg2lKq>us{wYMl}hYn`V8&_aMBQJX8#|1;xGS|s04i{^u}M^CEW zY(LpF%cctiZ-P5rkrRA{=hG5zDZG8zNyHM$a2xFFSWp}tu5>sz$=&x$gv%Qjc zdMvH@l8NW6>{msJ_HGD@U+2q)EJs`aV0Gse;P|9>*B8b288l_WNN7##au2LMKiR_o zud|IXQOE#yI&OyD5)_f<3eQlL^g4{2;+ckXS~V1!51woFH#X1nWkKsXDm!Vf$K6DWybEB4GDt6Z?z#sk-j;@2 z>K3bBf7XaP)zhDKJv{k#?)b>HC*2$~ntSrIxh&*MbfeO^%MUE$-tt)V}*1m)XZPIpz(2g%%_ zt$5@rxJiW;XbqExkGip7sOce+yYUWhC{AB8HENq;&&3QW=?%Kq@rK}S6sH~zCg&v=+iE|8# z-Ff+{R&p$bmnGlRzen8g@iV*QXM2#RKK?z7BGl3ksk#_V-1EsX*d3Uu7Co<)Nq2qq zD)Fecfo(z80`f7Owto=^#oEY;LMD4pP17>5m8KcBBxe-tSndNrc~SZcyja1`s98bk z^RV~kbFPWgQz@~kgy8bfzYrhRq%hq97btg5B-IJJeq2asw$~QH5@<4aEgWo{HGv@a z4=7h%blh1$^PNT{b%Ux}%6w8$L+3WRj-E7XGg=FGRi3)vZMT-@LNE=}`!s$_^^NEf zsr`ftUZU8!JOaV7R4x_mXJ~wLb>N-BJ3a&d>IAS4N)2rOyJbHxMjXqsK z3&yriD^mH`oZSBQ86EpWMX4)%cWQkLpDIZ8LoO|^<~#d-PmA1#(ds;YB0JCJMG-Dj zy`6r%v9IT(&U@^Hh04o3>7P)9ICHh#LT~xVv%6DCP6S9^#^ZcFaD<{rTyGF5lF)1N zTmH^k;+FBpNH6{O2Jgcet@?)WmRr=3-fa*N_6M%kO*PUfh|Pi6ae>k&OgDZ5Zbx14 zf%T0wF&CAnN5ez6(9W=NTrcgn5C=$v1*g(;1X=Oe z%BpyhfzD7>8T!HK4ufRDA?WR)gC?ob{Y$)(fNAHTtf1jXkR^vrjlN%sZrs7c82|n= z!+TvU381*9`ybM>UTmr=fCFX^_hbztcXn6f-km?#Q<1J;vE{hLrMS)}c)RRZz|_a+ z(9qTy@B6G!D%f$B@fFZRpo#4$fozu8t}xn6#`(Z6ZsInDW1rO80|t1W;(Y-UiFa1A z9gW(s_fXow7{eAtG-SE?VdwdZ;6<73Utgt8DViUskMqXC8w46XMl;r8$KLnsSYIv& zHp07tYJ~EKo~*u-^2NId0%?!5ctzaK=gF|$A$NaM;j!=Xs}OsQzSutT;Sy+ip>sfj zJVX`RBx&6xOd{g@qQR~WND^vD@PChBQ`|)I#x>ERllYNJcVx7!}Awmqd0dIKbRmPY)nTy+ih3r8R{0`+06($Yq=Y9gd<2>dSg!**>WpVT`@F_ST;$RK8@+SSO zG{@q;5ib9LD_%Vl6~Vs(X0@*Xtu)UgppO;y16MKQkpb}aeGJ+EN6Pje8G2kI7m#84 z1damJ$X`SHw=ws2gQKu~Uw+Cf-i`A5?JR!(gfdOP3hRt)U1#o-r z35QH>=%3Ejl++5lcZPb(v25$6&ib<$fGbKs0NXc1)U2?YIQo^b5v-rkKH zEkfh*kx{>V=Ip5k#EO>=X%iyxJW35E>>CEGALV^4A96hTbop7{{=yYw>k?zKnvYFi z`pTF+)22w|eU=6VcgS#!QKCilGp458&z0nF+;bPAF^8a-Y@V;M>_|8-&sm;bU$<3z5>V?g zd-&rVqLIpfV4R%juHufHPQ~#@z4E2$oN~H##DW+_T)~Bf-|IU$H*Y0ZaEDRC+_2JD zz_wB5E`zAR*299A7ji4*1nvaxsySDHLW$fL=jS3!&)LE6D_jB$ng#!w2+%Hj_`%1o zZ+2m}ziOJ4jyAiYzm1^5striIPdHhG7Cf{SC?V&N;BT+76G?lQW!OY6qsE3l&5yBGmU z24?kQqvAnsqMePxVY&Wk)MG)ZY4w^F*T*u-v8=)EZl9mCx(#n0FE`7zK?o%FseOoN zdd+#!X|^Ax7tH$Ut1MJLks7|7HqN`xnCww)8v;8f$gM;dBK;Q&L&|p*DYvQ|rn7dJ z#n@E3ds(gJUcKWhVX-5OAF*ZRWi{%uh|QKn=NHUuU^(kXcguDQ#RZ<_m=*{ z!&~jz0v#&N`O)PcOTQ)C&zse29Dzh6R_U2!C6_DNRej_J%YGEJ2CNSYMLp)ryXucvKjcX(VO1$E@xRxxbOgr`shh zOPLw3GYR+^`J71u6*)9IbDkmmA>x3{$r+&%g|J&Ai2e3WseKog5luSjfA4!K(|{cH zo9$sO_@&>XB6rZZ1!)i~Q}xOCP-!3&$FI#DR)SfHH@kXZ5bErcknC|u(dgqZ$MkG> zpdu#@Xh&pYFQT9UDZacPCb4U{8kxCtvE604Ch+nb9xUKzN`S{HqfJZx4=f>u;|I6O zh$4ee5D*!fp>M{XK0pwMZ4UCr+`p`}=viPsilf;yQD{ji-^8;WT?|#VNuVAm>A%G& z6pF9|y++goC1E5`1xYASyNNdRHNUuR$J^wsiVpXg1Lo{yHDS*#u7|B z-L5@YHD>Z^OiEGg^2YcxTO~CKunLwd>V^#moF5e@xVk3paU4GhrSv5X@|3>Hu0nUm-b2*_n92r?1cjDP7XBLx*WOulU5tve9t%diGAG{7!$4KGaMISiKf1w z2;lm$qbQ28pKKb>)H9WHp?#d!9wCH3o0v8Z?EI|Eym!Y zHD$W!Q-o~~JiPT<1;Y}1(uXaEbHApORFmess;K0b8T1GFF%j}9yz(%2(YUd@td6Pc zS`zi4Z@8t~dnC=wvi1ultoZC5=La3Ph7sns&*e}egpKdvl;Y%LMdRg#Pe_=5=F?NE zhj~2Cc^Oh82bm^W8jfkYR%*6Ey|GNifWD{m z#dPlRi*I@2n&iFw2`_wsJCyWQicc(ZT5H3vfbKGq6K52-r!^AHehiN|?}=L?X{sFA zm)K?AnldjEj(@}(@oZG%du`_ctl%J=ARE+$rP%_}NbaL~_G6AAXkjh}$W&U;?b>Wy zbcH_DGVhCh*~=ET>xOny7S7-3jnfi1NJT9yvs`(!V|q&T&x&Q;*56g>BVL|lCMCa^ zb}=ZZZH^o8q1vSCk=CN_Sgu6n#)fu%l8Y|ZdwO2((WswO)Di1D`UU}-8<=;HpKaQN zCt7^&jlGRb+cj9sGFjHic5Y0q5Y&t_FX~BI75?Vaf?{*>9?th1)EG}x- z`WxG)uem9*@J2x7tzi-deB3}Z2$e7!6DSkGU7T;9m$rzuzE zUn6H}Vm%F15)bj}0)7zy2|HlM)))IEL6f*t!n+{Sdpn5b{YcqQ%=)SwvV}2Ck6(JM zTlBD%4QSdzisb$(f1AW)UM0f zf~cT0NYP)}GT9wN6Me$#zPHqs?x>OGIv)1RFw*a};)9)ALm?o~r%i$%wocC#t2Z;9 z1dUy@eRZ=cWh;#=b;FoQU)RLPS&|C4Z9H_iPS z!)KWU&#mMjcoyQ($P1~S$&&T(Vy3#hmH})d+|gnaZ*fjQFLQ`r6YFx z@E4n`5;KhcXw5YzVcXieL*>EmI;m>jN_;wd(s*%?jwJ^%8mZdor|0L59+wTipG(=Z zJ*ZH_W0ODXvRNnOS@QcATf?wajiZM9dQoc=%xh$pb%s-1DORb5-nWdL2TP>6*wb;Y z8bg5GB`g5Etl2EmeFe~X9*^mE&XXynypyJ?|6C^*ML{N`pyfO0X=_FwJ5X@ZFsTtXZNGCeYsK~ zMF5&x_BNbE`lnu2kXmUeou1aiG=aAALPazwqv$O zAWQU@Nk7vwmu>12T{TTxEilAo10qpy@*rv-sl9G&MYvcPL zKIbGg9LCVbC3PSp*kM_ZoUJ0mrMqjF4C$?R%9AabczNbC;pGU-i(dUgEvRNEdTT0b z=thjA<+f$9uszqR>}Vlhhp0*I@`Zf^Vnmzs`Pmy$1uaEf&fKW5MM0p8;R$yGti&xV zC+XCe5L>eYVh*^DaCsNRjI4;lh%eRKHfL&zIf&3Dbk}`X`(-3{m-`_TDUhx#!a`rUf#xJ(VN)^K`zjj*|09qGK7R_~n(ml9*m!YwJCVJF$; z<#}t$@|JLIlI+W%6kPqxP9!dX07lnbNcP(>{E&A1!CJ7tX3ciy;j*b z0%VtCk;6KQs!8~=XRk6#`RC{Nel2G38}o!C1aE^MOw-pv0}pSM2zJjY7}>qX7DkrNIl+B0+ZjoE*nTxvm1*y7Mtaj5&DjxkO*WhU}cv z&HP9BfW;MXE?BkV zQIYQ70h4=^z{0wL+k*7maWXq{y5Vy$82r&lk*lU0s)vYwSWb%%ub!eQW<5uuG}W0U z`>25+k=S2PC09W-Zu>qtXD_IX&*I=Q+kq-Vi5Db-+tQJsJagFHRLyMT^6`UjK<{(; zGn>PbrXyt9fCGQsU}9i%R05#Mi8+?=-bDXU{m@M!*XVZt8m@8{p8}DjsosCG2GxFe zA6JwTHj3MCZY$-HC={#s|Gej}|ZpKqMglu54fT zKNR%?LAZG4e=3im$6*hyxF7pSv#s^tiu>{Y-udwc9D_|v_<)a4g`(wPS^^o8Bqj1$ z>K`MBA0w2w8mf=cj}Y_048=Y6ROZ!rwlzVx=uEpWWUlSg(=c zAXRje`b$^v@9t)kI8F1?mUy#NH$x^=l=dN$?0Gr zeWO}k>iik(pgI9AWOna6ez`d8b^p;3rnDZ-jVyhNuKZN3J4o~Esk~6J)=0Q(Oc%Wt znXrLbOl*U>+9K!?+I|rfZh?<^;=df|jgX28l+1^HZqCe8muw1-{a(WDy%6p>Ra(E# z;n}`~QN1t@_zPtd;=PoaC8tpfJ)g@--z~E~l%kFVs$+5m6Gn60!ul@Asm^9m#u|AxV={T?-E%zUjX~n?P=fE!SRdHe;Lq=V@r_tKtJt+yp)3s; z$|6s;6ez#(&ijdB^2eR=$37h5Wd+)`5Cq! zMTqRJ$Jc075-T0a%+i6@SYm;wJG*>;k1t8xakH_irEIbVo3 zDBxEN-q<||sK+v+zaVg8u809&YkDStF&Pnl+&z{*RLT`AVYEWAXWV#`HLGC`JLsy{ zcelpzkiU+7PFuEsMn8Z5Hg1;ypAK7A?WUGYAEQjlrHI?vm3r_Mz1a6_?$&1Z{G8wr zeMdT{QAdAjyPlUdr4#}d?1p`;%J@3DXDGJHF9qwL7VOONyd)lP0>Oikutkm$!#Jjq_s> zDFf883+da~K`U$@m?1z3slA+^H98UvfZd(^6d4mc?m8*os@|tx73By zJ>7Y|B+w}7@)KT*Vxo8~*z;g< zfo^8WzU%?rI;cPYL#x-$Unh$Yq2lv3kDQtQiUg?AYpp|uwI5<74_Sx4_5PT?r>K5e z4RLn489<2Y#oRKd7(J@0O%?PgdbU;B5JNIFuEXTpww*LaK6M%QJAtG~cbQWo-Nd$> z$#_Td9qck(FT$dI~_sRbSMWgZoa57py)CQV*F^SELY zL-N!3P>m+*gp00zJb!4`-5}n}tQI`XQ)FSczr{7Qe<$-7b2# zUEgERq}wopQxQ}orX)jId2w>vjL|;X#${Af2C0TRrt*tlJlho1Fpk>oE!XN%Ltif6 zX2f`H(@_a(uim_+LTBmIKR}6y!AeVk>;$6kt{fA%#}h`f;}Hxm9zKj>L+&-E`?tWLW#vV7B!=hI;8<#*Z^oEIIv!9D4Gk|FT-b1?T&b)UZkB#Vw6ApX5SbTh)Z}mG50k+$tX) z>>_(R4&797B}h3d#ZuzriSf)2nv&)11QBA)zdC_5k&>}JWlR3gG>58d#-^7sQ9QM? zdU9b-cfjI&5P_$|bCz&ih<6bePOO>ihD;%rTtccV-b~vh5KX>4nwctJ8#_8p^uHG^ z7@$eBuYX7|)D=bF{_3ULEY)Ifm!q1*KJ!*>yM~JJ?8j`Cdu~S+6XGyW3VFw+xjDZ* z+S10Dz;V6~{@VG(@!Dotk##T;+DL3za5a5g>*;FxGgx|UPyA6a_^b!IX#oKZ z{>9~HwpJIaxQvyY*4-3T?n6T1HSM<GMBCjxL$ZiVf29-9Z1U()G>m%qAF6(VLN zHfpQ6!w#rq*eHF#g-aP{C6cAOvBg8WM=l58{K{W9)zc>o+}@jba-~#At7y$k?q z-*cr&xV6^`ceaZahQo3%w;H?M?wCHZ@PtmZZ;ydS%e?X_ygzhhzKyMn2l`HTAbtBL zI?Dm0L<8D$nPJXiF#Y_SyFrG>rl1TILkQ+t*HhdcsbtK`{6@ILYeB78fuCUkCHb;W zDu!#0>-m{ZmoCUoEcr8(34NFrI(_-N!xs1sL9zYxr&5vo_ZZWWQ$s%O>mMqNk#*fv zBp6*|&p1kSnWv03L~7H-@bKr<xL z8uL6((|2#TEFN+HdGJ8gWuo8g!lDOGCr^aj`%xE-d3e3)6Df2aoNdDjua^NqNdhKW z7IB|C(k^+w_4FE*Bo#Rt<4$EM%{Pn?J+=(!?XlD?=r720I$MtI${B+={-2^gHcvTpWmtlvO=(6#nbRF9Ba`+cgTs=drxaW^kMCt_0(@?y6;cd&E&m(1v$lOar^4%v%7g}_#;n=5TTLwdyVzs65lr(9ZThZ&TLt^?AF)IPXWj;as zC;I501shV0yE|Aa6f5Mb<|f|FLk9lp=MXY$t{9rKAvg0H=L**EUy(XcF)EPm!hErf zOpoTfe7$gAlj|C?U;WEB-Rm3$oocETc@-D6=!cN?FdjLg$0aD3&DiaNvr)Z(O%Hy# zy6;(|<8s;Kz9@nds|in**cTT|^Yy0;iA|FZJWY`KBObkbtuzGPhzpk8=B|mX+=jXZ znbBDvtA4w!w{+g#%3Mw2%FCws2BoZ36o8C!uf#mtK`wfHw@aDI64p?SkD@K8@Ea6J zi1(S1odXI4QtT}-@{@tSi-%Fl#|wSS2~SdY;#4?4KcC95RV1Y^9|0*jk*{L3h8to( z(a-GO+ovyIyDyDMgtJzq)+f*NP=)D9f1U61rQeFYNPksUy5Ghz)dE=`qXZ!#J!gPB*&9EP`Tin(`#ITJA$7 z&#B0D<7sLlSqQWmz5TRQGW7IxNjouX0kc#GqNs`9AaDFEFim#QQ%U?DI~L2y(`>2N z{{U3DCe=&Mz{WU-`NEV9iDDk1m4>1dy-+2yB4mNn zWX)^&0Mg$fOR_ER0j-h`9P3booMS?CUrH^$9M9x}%1(|JsbuTIj6KRym1Ob9`GsP%zyA{@l8`&?D>w_|=qjsuMGNp1<> z2OMB6uT)8TsAamXetkmHgP2-PB))v$=-+xBU<;6O6;>6VVv1rOIFy`N%j!-CkdC|i zPGELzNu1~!h$`T%#Hf~LVol~{+tw6jToAGtNgN=<={w0N@qh;a;4xsfu0|y6fzyEY zP%59?$ooxr0=-uSDA)y|SXKma_rgfig20x5ss;AYNS6B`{n{*3}gpk=7YpNdK0xxTVa)e>lChJ@=r1L${ zWsDxZ45|tHftCi>sW8{J{yM%RO+4SWTl9B)_k>L2PoTMEb2Eh6JqxY{Au$10Y?d0e z#GS@yg1|eyD19TyuE1O}#oK$V#NzdcrpAv+7DagqEbfpovy4AZ)@p77)!$7@fjrSl z`dbMHBfjm&2lvuPf&9~+SwAv4s;A05F?){I+Uc@glVj4HV6G(2-B3J}$RA5wAEMS$ zcN#TqdcX9zrYekJuFUSmJ{P^WPqdQ7RP|ova&o!0 zCP1e_TusJxlN#y&0M(!{%#7`3gDDCH5zUf`Z5!|14f7FvYbB`Q-t|6Gg}#cs(X}I3 zh~&%dL#eDxNm2srUO*13ruoepy3wlySi=Mxs-QdKg=6|4?+J{UO3}{gMTlne3f>*- zSC-h+$iAoIgB-f`tluSQE?(PNaudCy!`<^+PdE;6SlX23C{BctC0ZIbRpy8N{0X(; z;7Uk@Zl5rkQvK3yt2=8(ReVKBb|XYgT&NJ`2H!K;1)%KDb#paKa#X3dRG}t&C;Xgm z1rw)@R5kY7XpZZOzZ9wPu3Rn0FAwhQru*=mp4>na23VonSL};O$r0S${R7+`&gAy@ejmwWre$`O0p88rw)|1mb-ud+241^a%RubW`V3} z-y7^R499icdg`cI;xX5`f!y@Z-LfuN{{e6g1%$l>4*0(9yKg>oHle^qL<}Gx-&$lg z9~k$#8M!7-&MxnN^f%Nse^ZW$3sz=fkcw7x(Bz!8zF3w2K1M9i<7-{K3Qp-B3rx3c znu*uky78V|VfK+7G-f|9g|0}L+Irv1+AbXX+ubW*reKn+lRa%*NF1Je?f&68M!Lo< z&xpc!-fi(^;C!H4hsDyijEz$Jl-I}03=gezq9wcj#@c3ix?G!F?)6iE=k;a9JUGb; zS{+Lj{Mc*#lR*M{3Q}dz4O*ns5IT$_5jKqyHFB7pc;x!XJMWL%ZsFW6P7X~uN>&Ew z4f(RWDL{ytMsHobG;UPa$$h8u%6xE=lNn>x*BZY(Gjl;0%|eKg@vzOOagSyxxt5D` zR64pYGJk>Zg3PWEx$B@Ft*>YSOw(@x^A%edu1v0Ls?+(@8WY#{s*hky4=8lQa5Ul% z$`cEJO?Sffo0RX#tz!>%o;%^z8k3yEe7<-w`$VRL&dO-95e;>DJtV-+-{bc%IKj?$b`Bbt(L9a>Q*{^ z-S#gX3c9^vu>0{&P}=>AW$3ALhcsGmPqi$1`q#AYag2rNY}fU^qSkG%(S*Nz3E2Ac zA6vG83@-2R%)%!zKAYy08pGp068z`$+8`gO%n#S+oA z0rDrjzK>$*NK&46TN_V6t~vRTdXzo1J<_{_mZlSc09geMqg%x!6Wo8T4Aiata?$s% zrBA_%`}{eiV?6ZvlSzR>nl5sU)rm-*NH3=awh~7;1LDu4dUkdI=@x>q;Ky^dvQB75 zk7bsQFWPhbWQQtFFxa)nBeR=ohj64;4r3Sp+lS}Jxsv2WX8w>;u~1qoh5sWyp)}+Z z4Q?>mxc~1~J(FdrKin>?kB7tn_ly#I?dA$#o8trjktjUB$&|cpt1l7 zxW|S;xU3KoJ=7^nXWHgaV!E6vXYA*-)TV#32Jo?FGCFtjF4P-1{_t}9*cpY7#KD-s zSD8&&L(z1Ab-TUdWpJzYkswaT`PrN>G*Y>Ea(J)R_xVTa}0%Xr%a49=YCgN6Gfuglr^{p4#G zxHGlyZBe2iCV$HIw(8K#8-?D7Vs&S(ueV<jW-lL7(L4@;Ptcw-|CBtb2jkx7_` z#DH`)R->9_LP!7A0oo6>6=meHQ&c+kEN;zQ=R$G!6`|n3vU@-n(;oE{iVjf`?n(B; z*DE=!NT&38W}%al|1YXMMB0ZZ^4M5_*vGfuw?YQW)jU}9XQ(O=FScU(a?`%C?%XKS zxpf6h!Cp3>4|mmlj42HNy_opiS2QcVX~xtRmA2`|s1&@Z9z&kAxk0O2hW~7Y`gM7Y zGr!W{30qz9>#6DMipcgyfv>>zoBH{iB>3HEFDP0$Slf@#d5MslUKm^8Lp1b(dqBk4 zz057g!N^ER*^Fq2yZ`+0et*C36OUMj@jzUEWrSi5LQHiErM6|NYReY?&Z=b^+H?Hb zH9}sLfX@J@5$11DVdO7@^l+S$1pV}kVhm6-ndaY<7oC#cKX2X?ns?QCrGDHuNpRMH zZj=xGYvO%q-KeLco%m-gYjHsQLM>TEDVej}PIEx&_^Rz6VDJiMJl2AnOc&aa9x3R# zzhUc#ANp-e?6IGr#`JY98?F1>h(zcI3U?!pW>^bm>*IiV@3hKW$RUM+!&$Zb5ms5s zX_Hf(filylM6Jz&ON?1X{mbi69x?SmXO`Hr#q+=f>>hV1gY}r6tT4&;;42A#hAj?W zcH^G{rDrJ=fsiM_%?VQ$m@*iq>($)W1w zrerA`eG%K4h(z{h-~pL*NWsDW6;N$VUS)tKR^mV{%lEAunsVU8Ea&JEU9{gg+4MxP zBt1fr^ms4QGaHw8zL+Z8iD`J;*DL=r8yBBtG3e0FPCz&TwT}gmA;=#k6;@N3{R(gg z)8Ay_f^bw{X;&A=>7$Ibd~SncOZ+|OS??!iS1_Fbd7>tN2l+2S9YwHZ(`1D zY9Vy(wPPpTTNR61TVp?E#r)sQE*=>mNJ~Ucj*kYz8HN}d@Bcg+&vvDj+vgtd)Na0x zvi){uG1}s`Idrr%;X>Fx>Yc4A3=Nr+7-Py0Si{sRD{yzD#YRESrr+1+rf(0$pKj1s zsZcjT$)S*;k8HZpKyU7gImKqoj4^g~<%xg79&k`ue4vBj-r&gPI0dfOn}DS3hLEAy zbA|#Sk0Rn%6hrmolgL;owz;g9&wDKYGAL8;Ab3@;ToK*xLlSo`twn1<3aS+vNm&)rD+qy!D;O+ z{;6~lpi3x^bK#%o1*wBYwP@K@EFqt+eTWdJRG%?L+!GZ*GyzLa`BKjs3B-16z6Zr-Dlt}s2*#0(?y=NgJ{gDV*(~%nVTVoxp2rU9b@F`6;+Y$$U z(tbQxv`zZ>66U-|eKBJV7e}~w%l0huO1DRrZK6dadleaG z++@ge%>3n)9-i|dF!3S#)QV)v4dCC_bCD1Sl0mO9Q9o0iI7UCUdj?LvYMc4vNW1hL zr&8;ltqQWvJ=K=x{oHWa#Ew5J=fljR>__AR!UcqYoKOOzon4HmX2*MNu`QQ=c@Gm} zZH00lB!pZ&7(M(#bIGsnV~eW$Xrtnbx9>Cgi6rUM(> zv*ys6YcVxfZ2D9f!s{f`2Dn0f4MS0|Tu0PS*UD(EpXPW=M_IG{vHc;*(;vkC!-8h8 z$k7d5*09^I(l;LI8k-~n-}g)NQlLH0@)~$OxfsLsK6>!2F!?1dL)-86=MejLCM3o- zMR6sCnW`mrp*aGg~)jz@5BTV$sL7k)c@)6q|<^q=mGYYvy{7%S^=Sh7zBqKF3IcBTN!dlLVuSN^de}GbxN8Hf382O+ZM|e2Y z+0iXnSpvVnVVi7$1x+v^*1;d~bd3UKytyKKX;=JW>h%NP>N^D=(_dL&dOYh6oP>dj z3P;Nb_@Abe8vqVOFaGXZPI1Vh(gVfUjA!8i3mH%sT3*y zu+KZHeorcDPS-7DYmY4~h^$z zYWTC`y}Z7~^#FN;X1tmcMuN8k0%JlF+pKbMO#=3>!c#L9y8;W+@PUxXOI(X_^LEB| zr}k(g*iTtjhgE*`RKf}6VZ4!@|F(h)t!)QLE9uQvvrTaY3(>c`9b-=!6HVTYV!LJ; z;Y}M9wcl3nY_&*R+7=0WcU{C|H5q>Ok8F%-UvMK;`lJ8>SfyEyA)Usu(g*43Dqj$K zd{e%=yr6FH<5H}n)Ff5J!I(X=bthiM^cUMo)vyZd-jOR3zUHfGl4Ydc(@9}j%K;fUr2e|}wHiB%hbZQcUrcMMRf)uNYoo7awNJ%7Sfc$e0Tdx}hfreM`5_EjsYwF0z|*@8%-@5+{M{ zcCy$8eFz7-kkR=QBX}e|JqA21JP$==2-Sjl&kS1PLVCmiHpGv}B*t8XsT0OXAL^(N zn-~(^#!Fi>U)bJ3Z%#7q*^njcpr|v2hqUNjvQ<-QW&e@}!;BQnOrgqEo*y3T5$(|V z=s@7tT=JImr^|{(mnaZf3)a0$;8-4GR7_|>(W z`I!!FvsN7`HQ~tt%Wvvo&^xcelux znNBVI!p!(mzS~Sozv2H;7}Z-_>O^sjBcwb}r&J8stJemVLn_-lj9o&q$e7!v%joVs zJ{q2is8}(*R{Kk{PP8QEU!LSd zXhrd~F(TWamM0%on@lF%;C4{JQkM$-od%A^Y6JRM ztlgf&QXKJUwr~z++e|PQN7DXF5CPbw=lm`C`ixZ#rK@UbrE{Okw)IQ!)tx@AXL@XYmfoOYX~D#WvvT0RVB5bP?@dnZKIjsygh8{#G$=v^WZUwiYolb&9#L59#jd zieDtNw5=LiavXN!Qgin+Z>3)f-g?n{icFf0W%#re?z8puDGcE0sYoR~7T+h>t1xmJ z6pL%D$7*obd~*;p+2VoOO&6V@9B11eBf}+^4o6iNK>ax^ZLFxAXz|4844hA%tk`Rd zOk1P5F2(F!knhjp>nc55<4o0ThT|$f*Kv%JjV@y=F}p3#_3I9~Wf@;nM77)*Hzt}- zmsu=q#=Z=DAK*j9!tAwG?;E8zk+z%aL@dbC7cS4hMr!d5xUPIbtFWd}(iFn%_H;=zV8-{-$PGbwT$PZQbjjVf8=FDy%2; zZo7f6k-$xHf0Yt2DevK^+mhhyCs^zJPm!biBw!1<>Cerxe5=Nd{GuFq zyfGu8J_AdHd=o*D&~1;qI<4>_@jZgDZfj(@x}HRlqD$In9{Y&x(uR{pFDTDv_3TzS zLcm{ZUC}LiL|ls3>wNIi$|#?DfKc3bIv`6%twdAM^M!UpNwG$`X=!T$ng9b6$D?Vk zY@EPo>thzBda8!pT)U<=>Lae`GIkMqfu^qx6}OkULbg?VZuh0D)Qi&~3R;?J6`gC* ztALpk>CYwIn`Mub%S237{YvOUvm692hVtT6XouZ73D<)eg)1AxMeZ>BFx=mmqaP0P zhm7l9J0tHmSyMJOL#)RN?Way+&=h5o()^2GX_y>hxa!G^0P$?vwf$O8DM~_nr#sCl zbqf|}Ua#5M)Xug=qH{N(_jo-Eq;$evqr#!Koa_XgFugxIMq=@TssuV5$Hi04*>BM`9as%yC;CNvxeI zj3L>^1WrG~jXOA!!+bi=I)W3warf)9#aKMttyW%IjPhw55b$6pIXNMR$H0+`u}x|# z;;Uj>O>G65Z0tzlz5g(18Ek)lCH6??kVd?rJ(GK^XeZavk&A-?z)?yEMb|W9)`0(i zc76Xhipm|%YruP32FZ!ILk=t<59GL_BEgZO+`!pYK`{;suE&8~nD7%)!OVrCWJPx8 zi5gPPF9ky{n0O##RLc#ZRb?_$7z5X|Cv|VUBA4cIG_uI(9fFR#h;BaQ2DDyvBauOGxL^ zKS2F*vREd~!}k`t(!UXOdIij=oB}s!9@+FAjx(SSYoUA)a7S_vY`L zSf?IiSJxH3m{XE~o_RiY;&g8cKHC~?bwA<%fl`o}e$sqFQt4%HJnvNokvQa;hpUj`vijJWBvIgZ>(2 zA&u@r`UxB1%L}?o4=>k5kXiR+AC-y%C)y|`*&UVmYIjwoTbEjjEaQ~whw)ho&D=mO z?Sz*u-+gdErsI=*|oz%sfn*EwbQ67+|zoo|lvF57go0!R{CJ}mWa1OewWg00+ zDGL5{YE5o_H8;9ATU(0p%sUwyP46`~W>kudET6DOVt*AAj&}zDEHAN(M+H8x4i~bx zd1U%VTC7ovEmEVc`_={HVtzfGfzG+g##di@o0n?I2)Mc-Ekuwc_SGJV1!wFfF8XT_ z-D@KnUvXczLQO8g2o=`Scq7*M3D!ELB$VYtnlwwELw0AOy|J| znHAnq=grz8iL6|oISYo=K_dNMY(ADH37jLL(BS4V0!F?qdE;3Y3X)0L_H;xDp&D3R z50yqRayB*}Qg%+jhaz56S;QDSO*<vjlm(32ZWB4S0UmW`c=dVdHW?ERPl~OKO1YA!8)w-ipC!VcWM2aXj zwM?BzKVV;fBoKS*aC5TcqyLjRdWQ(&-j8SPZ)c|9s6ec@zA{-8Z%t4 z_JDJP+7bBQT(T2E2Rg5gez93~tL$8|5~C37h`~Elf*;icWg{K&J+ZSPz^|BLO!^_P zRzhFzigxgKeA2Z|8kbNmw~GDy+0q34@1~NQG(8lrX7wa+xMK>y58~w(*X9^0IWQ!P zx{TlRSF-WLv-i&=U0DECOqs%K4&`rfibUo?Z{Lj&f8i>N6+@*RG$4XbmM$y!k+}t? zV2hsQHlayN!NK!n55_ei2V>z+76Pv+~mi} zRhU>N!$ls?%LbgA>n+&^Qkj{iQ|g<@!=hgk0|V$A%`G@u@4VrMYN&xmSxkPeo?byg zWC5qMk&;D?>jJ3t1g45<%W$p{NQHY@5Tr}c!%eihoTGE{ax-K=nqH0DzY@ySJmeeH z3wmgfeZA&Sgyas`l)>J3RPu$nW|E2ayuf2P&u6|l)h`GXokiWoYHwiFb1mt0x%tt3 zF>1P8Jzg2+b{oTdBb%4LhL_yp(<(zA%}=w>XXqTTx4DMofkZl;N)PdeuF}`W3@yK$ zzA83SV3+ctUP$PqB?>~$X~37zL^vB!ZXYnj8VeXI@w_F8Yzi`Om-Gk=w~%v!4GsdB75A3k={4^h&xd)%aK%=N7W zx7U0R;*ljrSbeXi?5z6-P)Q>kdrm%SXo>YzqR&-H(_xc;$>*8_5)}BtC8C|IrqLSS zdUUIsI$)8PJUVX)GYn^9@dk44Q6ply$zlS9kv;doyHKG!6$^o=1c$h@>UyL> zCL9U8rWkYoyXH zxIeAji!9?8Z4+DEszNTRZ0MR{%rYP5MziX|Vq}R?&d8=fW`8cRsZI}iCH)xfzoKZ< zX&dKDeRgjPB4~8>h0)H`2SN88gP$|PJdTM63iW?hn1nF1TDFzlfvbU?F!xhiW+w6V zPlD3e0o5^+wRuRfjk=f{F$@@0d{sck0~WZZdIZG{_(S-6C$+e%w+;d!#`}AnWwP>vjrLwz9>6SNgp2 zOpOgD#LTgQCyfkBrnQ)IvQq{97|gY_8Py)`x%BHJ@71!FmkVBx-^tmh&n-!~Ngt_$ z*qdF>LO@x8bO`I}4YhF6EU}@Nn_wC%CX6qnu*`ED<&D6D^1Xg|IVkvIyX|@}^NTZy zyL6RnU(;nq!U}`SlftkMI}CL;M`9?-`VItAcwGqGbUtC;h#*;7ZYgaTCf{KZ2YR}~ zrZQ(PZcaMy{g4SCTBiNQ*2$3zUyngi z$!~^ryF3OVxrTamb#z3UN&Us8kXnYo`JG2|E)lNp@)O%Lyky*#`e=idNG3C zN{5U$kyYaCEHW^OvxCVdVAxt+SY#U$NCC2Od&Te^JT~O<+ur%>#&@#hK;4c`l2|2s z1m_FELA@k76YeLsyGc5npikmD_<%j^)_;KFjXlPXO7cPafAT%AfEo~MEhINdK|OC< z0^&P$QxM|$k@wRezR$Ec0X>ASTbV`zr1&f1tPkBA?Nqk1d=pMB@*68}7twu(c=5Bl zE8Vb{zQt!KnEVJt*2whg`jEfcq*Wp7c+*A-6i(Z)h>ibk0Jl*?9QRb$Yr@0(8t z2j~FuoTra5r&{uyY3cMSusXAj%9&H9SmBVAk7XN8yH)(E1xmkE_Ov-8HLo=?CGHlL zaH|cKdSxLMhoKX*T0)>R{h*a$#5QOLtX)8$cXIx z2VkskK_{*Jn-eK2ETIRS%UG?1(u zVxhHgxYsQ@?r($)Gao56>jEj)HpcL~(4E`(XI+&`(C&P=7yvvP`FKbOblYu;8ZT)kv%=J-cc>LKT-o;?P4;jD}FTfBGl<@cdp9N$N_`bN!YC^nPc_S{m1=ztKpTQ+U0FA-AC84T``Xx!;sxdejR}>%>YB=dPG5_s ziYwBD($&+|%)HB|J}snEIkH^x!TeD(-CDT&WT~uYs?7*~BR&Qw=;xc@6lLe=CX;B9 zuxpl{i5Bdr81d<-axMR1DgeQ?KueJBwYvJwij$>zC(pI#NKcs6Rt(vas_Nnj)u72+ zx?Rz!BB6IZu3|5ww*37z_Hqrje8R4xK-cDozPNjvz(Jtn3}0?*I)jJYym{^H>o6a` zpJHA7@joW2-<>{=FpbK{??zW)m{2q8WQ*(@vf+qd8-e96o~GWzTK1>ESi^(3S%nG> zccK++SID^b2Yn5!`SE@Xm2inHI_0x*)NFCodU4b?**>HMi?6Y74QFh@x!xzTad~l0 z{|r^tV?GhrfG3HAq168yLS~z6?;c$G9KiW$R5V2y;fb@dQRuW{-$}( zd7XdOIISgh0Cq20Dm=9-vi2=6LxFutM=TB*P;q!+FBzFqrJUE_v*Ak&MFWk-yz&2D zAIV!n$)izIA5&bpdX;4YN}gZ2yt0kM@Y&&-x9VB7_20UBcX9pdteoaIC!GU9UG0t- zj_b>pNL_g2_E(BFMhUZqhu%Wp|G=e!9T#E1laJXj=2)tl_5^pvZ$|ykqHo!m;9L)R z55Kzx;3*r*F#PU-KH$#kRHSDGHyx*RifB9P4nub;Tn`l7BmyBQkXq<=^M8OV!phC_ z*WY#GUlub9p|UOJ{Vt0vf2gOLH15mSN!&1*Jb35;Q{5ju>Z41VVs)EZ&1r~Nz$M7_ z<{i-P^Rt9xAyBF%{=*MugOaI%GDyF5NQ5sr9M+>nG=G%tp=`Lgx^y}}-?#lqUU44y zMH!gn&wr2k+D=J+S?Ys12qpTFiq9XgH~m17?FEzQnR*P`_YvUBcywrGqd3pwA%lJT zg^a46VAe6lR=Xd>gO!Qo^~rCRGd_fL9=po~X3iW~eVi5M{?o{?G0X3gVxGf}?-^)_ z9{2~K_RRH)Y-JnBsOan_?SZ=s4VBms;72_jk3*7TwWT<)+)16})>1&K@Fb z%Z+{nVA`8(5xb6j2$QAp_qtaK_tbtChLUF%BHDlP&^w%ZBrX44I^e2v%;cP`j-fVbbR!e(dh_T{uRqmwSDj=HIl`vO#a3Je*t)2k# z?2q(i(K+NJwQIohk&eZc0eq74?P%^Hf2~~kO$LfHrL<&WX7kWr_}~HkPxwq?^%B6C zB>Pel^mi*Nm=q`t|)NA1`07zS0&g?T(>GKXY`CLGM}(3`NP(BTMU$MdgpJ z;w2@(PypW zspz|(=pl>Bie5;REbg@DiIIji5v`6P6jr=cqh8hdGE+U`i|#yJFvx5J^MD;cREq$E znaF=`n;TZOR?HR#a{!R~?;6&=*a2rH>tYs3`iiPi1Yl8uEiSa&3M zq?o(E2)uXlg`v_9)&7lWs$1U3J5sujpP?VYfY6MvZ!vqQw0 znDC&nxAHszZuVK^MdoEy%e$vcn%^vb&-mSlw@P7LR}tmrDiyM0PP=WL^i%%bu|jeO zbHGsEvo11nqvCtsnbL*WS2%GAV{a8DeYeVdKfKxnC#!@HLP{K61ruO(Ur?=w)u?_; zTguNo-6ZkCUnRZvmc+})%Jh!i@|L7MK1_<3Qis+R7q;-w#4r$X^#m#RCV7Qs&Y-JhYQT*>^zRgRiH4HSiwlEUSVV?z004MZ_-CdWw zjOFw49`-)SYAA-1qKFB%0~wI85A6v0gYrY&QLCsloPqPR?TgHigsXCP&#%zPyS-Wm z+jORI^bE=9uLAXttOiktpv=e8S2V!5k|@I0#p1n&PIaeaACumn>1;%QjIb$YFFu4+ zFcc%6+Dg}lZbb#1hfu~2Dgmx!+%hvD&#!M@ip1OmDdYZ)iGwmeoI`+n?o$1M-Snhb zZv0Y^P8f!xIY@X+^yNdOzytBz4g`amYt+$}oKM)>LP>g3d)K7_EW zUv_&H>LEzIU+i8|-!hGb@ij7NeCaP?Oj)feftcJK{&Y2Po{b=`i(vhySnyHi8)-f} zPWKUFxDG3G-J@|uO#$J9bm=`7x-s@PiQQxc!};F&(|9sjWd8YrxsKuoXwjI+@hMK( z&Ziyo1l!LyF@@5Bwbkn&{1xz!sqy1^7g@xbxgziEya(PFQ$OsHSNzw zXzW;3@&FRLw2n>(oha8YS0`pZwOJR~sxbV`v0!~M3R)C9`+^ni%UPY_m`(`>)r5IJ z?*kJJXw%x)tNFZm_V1^Q@?Q^Pmt)-?*1ymmMBEVkRvU1mipX^(AJlw|PYYwH!Z$-% zSD`+x50O#N%S26(n<(a&u`hIaj%wX#`r1TNn%1}(aZJzP=+EEJp$Ra_6JkGGSlr<4 zdM!z7TH!929^C#Xj@vFV3rcthSY6?%tfE;HmY%w49Q@ zF~}xO+V|!7I0Y6@nMWtN2o&1a(}H+#eDOxcc4n_b^v>2_p>S=e=u7E2Al9@gNsk3U ze=bu+CSLLJ;a3~3J%k=*_B!<9(#vx zB^V=72)ZKC)y{m=%>0JHB@G%3)~<~a-=j|$JAib?d~0JuO07$Q@W9S9nc|xFKZ^q& zhGA#Juw^*@C%|Pn%K62(e)REn*Jqhlwc6hzml`-K@YV1~A-VIt4i2BO6`ULZw=7Y< zWDp(SJ=WRKw&mjQI@{vJLW=I-BTYYZ-WzDZ1(fv70h(Gi?Xjx~w=-#fx5N?*)Nc@0 z;+wgtqBvy!rAu|NNc4q?p!ESvMxtD5i!A69MwxzO2+wuKfLaf5e+B9#V)rW*j^g=* zyHL3bI2$5jbMaBKNokTHIQ%QFu41IVTTFYib$L5(G+N;)n(wV@Cz|L10 z&NbZ0^|lK?TzKL?FsU*9e6rc)ETZRc-r5%4)WgH`0JPdI&~j^Hz>FvC=STf2$W^u`W$_WreC_o1Yy&ph7> z)A|9j8=ZhpK8Nm{cV7hhshHpULZbs@v)1ijXv2C`BzTh&;8HI?+vA<9 zXhjbJYj%HP9x~KJaSqa_cy-7DUb0D3pI{VyxRV#teb*E1v^o8xk?m8;P_$IA)}Xrk zHl9Dqi}IzqRRsO%ds%}%MIkrj-CRre<{2@1e2ob8(G7zkd+xJX(~6h0mq3xGmBQ`B z#9^t?AE?puq=!D(P*ak;f~HY?1RraJ*KMd!MM5#_QI2U?zA7 z7bUQ@EcSdS^<#hmxv$}ss_Qo^jD<_LTR(|Yl`C)h5JRby@LQ~b=3!E3rnfCcmCA5q z%oO|uF~7#?q4revB~n!j^I~q+wX=;y_|p~epsA-TtXBAHbBf7}tl9b9?0a9tW*UhG z%g3L^G!ejz_VRT^KIJo{J#eMl?{kuR#NBGoKM73B-ebXs_u2}uM7x7?ae#gm7^mQ; zPw%m@wd}Q3GFuq&II{t!AoTK6ocBBCum}$sWv0r8lfHA{O$rqE4 z5glMBIO-104w??S#b|PsWwfYT#<5~0(>d}w_M%&Dv263$ahTsjQd9wC`7|)(e^#4R zS)u{Owg4I!m5wzgXakd31x)AvKMnQ#-#Xp3ewh(4J|;oyKggdJ$hg-3z+drazdl!u zNexlqrtapTJq_T5O;J-#loYf5pqeRXcZ+{&^U>!Msq)!nkqz_-xdC&+SSv;UL_x2! z;HV!Y(>}j?b+ga>Z$+v8C!wHGPR_?e{u{WP~JA$sgB;M9IK3nl?G^jeaV zn=i;!Z+kmd3%mNP0+sl%b;=vtoegW_ZcBPP3_bp-dr0_3D!Aw<#(4(S1nHb-cuOpR2Vu{GyJkh@BC=4_ z^;>Myl#imvdbABIr2c-YsSj@r8xFgh`(5n!TqqFq=M^+@$6_;_p47*iMdlx1OoZ_n zJ-ww{gu!hphvT{{)`7Z3LOPsF^`yGYF@O?tp1RK=8>Kad1(6K_LYVf(`~~~n^9M}i zSa6-c9Kz*0n$VvNo)W^x$GbDHd-w&-g|60Kg=@4SFjY!zCf@{=J-O8{j46PY#$YJPtFh|BD5 zJH@LS=vADeSSoSQoLj6_$2!=iwqBqrl&Q3FHMY);zsb}K0-yvZEoWOSHl_R=6^-ab zoTi4?R+P~OmaZcVD+e(;!lVj^od}iLo}A||f!+?brK-?aNUXaCgo5B%D2m_Q z8Sm{O)PVxahqwYhVM;X|bNJGy$>zgU_xw91R7fAP=MlPp%C;lz)*VEX(l9B}<1Ka= z723r3N`~a-oZMf__){(6ovLd6v6Rg3A0;i8GlSdTNb4cUI@%veleuj%F&@SqG?`^R z(RUY)SA5DQCo~Mkk{i)Odci|B&bs(XzY5DY1W!Kg=StVj@<{GEOU~6(tL7W(#57aq zoJ{v z4(~C=0w|`9#J-IbVLCh~-*pq)TcG9MQ?;a8MNS~oMKd3n8Huk@0^2$7R$A0enlSVS z)wfvI$`;i}Yl!rud|ZbD7CO&RhadTP@~l;G>PM0Ww$}q|I|9c+HSFDxpe7t}3an|e zBf96~$ByQ5zC5Px=cMhrU!~$eEZ}=I;moFTIp#h12I?>}OS56bYaMO&w~T8Y?E64g zIMR7!<|HZp=lq7nqsq&661fsn1hmD(Wnp#Q{XX zp+U8cFh%OyvU@Fb7vw3#dJJLwo z0T9b%vB%4w27(*5T@Ud`=30F}l z`QRYcHzbZp*uPeYKKbHME_gG1rnAgZ}uNaMGN0fDvt--SX;{(hQziY z)|g`I?2r4<2;1%k5m>Wxr)1PVJPvHD?H$aucK4@;>`C8sl$n>;85uPTC_`_!-|ca5AvH z<2M|pBbPO8_Z9)d$3dz>qbyEt1&&A6FU%K!QAgRzek6p7)l3SIhyc?61%G+e zLwMK7kmq(y86o?Nx57!1-y4Lu0&KYrEcHPmYfln@rjQ+S;{}=zH!3pk^r0E}L1anRB*O^)BZdf{#gJ0YR zbs-cUiGzBaVOHaY^IILQaMj40Z@<_p>I6n{A~xn6YIg_XxlQ;unCDvfzas2!Wq`k2 zvAHrJq44N(;%8s?+~Yq`cygu6%=L))8;oyxVTpdT?ORpE;i3zb+S=`DRHkp%4~sq!tl0KvBOTMa-@jX&yyFCuC$RUl%i0a{+J6s`vhQeu&*9* z5bd%E9(e4n9eg^<(?1c<07&rIAFZqpMC`vy(Ullgy$x@xs=A$;V2C$=6Y;lbPxGk( zY?{9;-oN=fG74;S5ecp*0DjEv{fvgPxzbCcg-ZP^vu)2E3*Xw{?fs2@xAy1XhUO6? z`Pa&Hh}3-dRsy8ad5DKnwFZ%Y8!;2`zS@}E*)!Mg_SQZ!K~Q(Tqpq@j3*-HK_=?E0 z&TiC;y+@9qIg^I`L27yw8KMmsVoMQ0>&_BFzB|GpXl2ngSFX42kKMfe{H&>wtI2Z| z4_H~2Eggvz-&L_xUOAP|Hqhk@rZpY_U7dA)p0fvguMp-@bVF^>|3GIsn@Mh25C7sI1v#(Aw-FhOIxN8`fz*?Sht{xX*3c zNkDzo1AWjUSBwppc{`#>s1QKYLv%82ev#1VM!u`#yF&WuFb)WBB##{Bn;VPSn`xbk zE)-5A48l6n4GElb`3X$?ew;t6&9-+SA7u}xgSjqAl{S42q<^FV9(c)y{1xY|r6?(h z0NJNPfa$eC^S3d5Hv7v0=f{}YdzW^>YEGCO@%|1*d^4Yy7r7x@;PlEH6i)fJa|CUD zQ0=1vlK!a%DGDc3`s*prb(22pF!T25H2I^JRmkzSa>bSOkU%~5Ub6*z_uskgiTz)g zX3)=WZmpYHWP3X8d;9xMi3RtA5Hh4>ADrhFPuW1X)}He9twZhVW6VpSr_dxAkc9+{(aJpS5V3SJ;Qt_HTplIoaG|~sQP~Nj&dwa9at&ku$%kX z^OQ?4!Ij+IR+=KW;gA=s^Jm=EK0caWV?JVf92FRNvIkjNLcKt9h0W==At11F2aC@7 zF}t?2tH_W=KNT!$uh#$qq)m_a{4oNFQ)gRV1y5B8a_%5s9$ycd4=y^{*q<3VeGmpwXklT3J zz{*r|!J6%zen2JZddVpMtG07(Ir5Ft{bMu?%B&V;<2Q&ubFIaEg7)q2bA$&85Ns^{ z(*Qw^s8)jh7lQIXkapADdEUyES_GWVL=;Dds1m&!gGOa7U9xg1hzw~ABN*z&f%F+| z;i>Hh|4#6?k7^osWad}@IJy#uJtFyDvI`wJX(NNiq}X5w3C&`$=_dUC z<&5m5Wgbq{9{!G3vn+=DV%sKyY20T{5mCHCvN7##hoVjP9PPD=Hj?Mv&7=6XR zCD(!gbA;tiCHV~(%g`;+up8PV=bhQR=n;&9%>+$UXadJ#BZT0nZM}W@c#&{zRk0<7 zVhcY$A)gH+D1C2({V68sx&1G_Di5V~uHrPJ*FSvJ3D+(eTzbq|s6}dCT)X`H_vv%P zlEzD86R~vBLUY74w7H)a^QhL>b7$Ot`B^^qv}%&_WwQizak$h0ZKC2O(3_R_3{eDp zOT!w7P}`cxfH#vH zc60NdUD;Q^&qDlk9N}ccD4eHUDKe>w2ak7u^*`+%WocDl56pf|kCr|Z;ZzZ!V6MfJ zAqbz%>|`lf4(G(;aQYQZ1A$%bHfsBd)MNRSx&99Ns7ESs?%DDnln}`g5?De zRu4L9PSv^6Q60@E9b6Mlhg*G$9)%}ql<0GFDY$c7Q0O0@6RV{x*rk4!RCPdfI>ULk zmWB;z6Ct=i6Xh2?;+?3K`t|cZ^h%EPz4o_lvZ&HwpB-fVipx=-@P+8%?okK{fvKLo zhspK-2cnWd-yb%uUcOJ!GlWk4Zjoh^torGZ!lngvS~g6;7P@DF#qN-Wc5&J0Tleuq zQ>v7dZ|A8=HSpk@ypkXg%zDh3$|J5;J#(PeVkiG0Jhs`XQl#h} zdZ%9JHoG)et8>2@nj@F9Z7YQ}{rj>se!TpXtgU`-fP^Kxp#A&SDZ@F<3n3<2zgs2U z;hV2CK1iD>>}i=pHomD(+#&~8HRyM|BU0C$wzwbxxh__a=%LswG-pnpsYSy3RSN@% zWDvwdo6d{{S&Rd%cE8XE0ZPzGjOdv(?pZ=w*rVEwU=g4nsVRiKVYn>$iC$KeI0nS3 zk6DW`?p^`p=%r4&X+Pt2@QV-rK^~ya!~N9mAl4}?mSY<~j{59t z^c2@~o1SB^B1oTa%2wciSMR@ASD*eLS~Ttd#ghJlb$);(UW49}z_=mDEQv@;qowZ5 zFD|M99<}*^rjcnGdCIh<{?o-%;sDi`>3PAwm-d=T{#eqs#EW@Uxr_|td0Vlh!}=ze zJQ({1<&mQN0sYOBj2CIXw}-`t58u9hOTeir>v9JO;Mm=VWHI@$<((> zLJrryNmL9f9P&l`;3DFL1SYSOs`e!H4|N}VY8lX}A=Un`So<@E1@~839Caf$F6J0> zT+hB~-f+FQQhrSxe37posV2i{@k_<;AkSWmg;_#*U0|GtAqSFbbbUCyBWpNz@_y-$ zZ}Mxc?X)-onVVq%sotFgMsRX>VM1u`W2IK^%g_QfP#xK8vQ#cJO{vw3u-?Q`U|%>-dE6pyQDry-7- z%sdrKl|SHg2}&y_;H2(2s1~H6S{iRQZ@G0wuq+KeLAG?xQh}quI>NayJ!=nlQv7)^DiTvKIQF zB@ywOC2Bm@DsE}EV*xWN$Vxr8!_=8}mURegAd8O-koW%g@_;lvyGT(zfxYvE$3YZe zyUW=gJ)>;UlJjP&B5}04w?vc|1g&(j!mA&T@*6Q5@&x8s0*1?ZGV#lQ(a1j9rb}B= z2mC;j34T7+c#)6UZATjgMi)ffI0Mm7A}%R<1ii(k`}X~wP0cA`lj^yyZ$@>B?9C?5 z)^#oWQ(N~|Mt1|xslv{9i>Ezn$EWSzA^wsOIq|-p<(0WlcRO^IY>bX$R8&OySzlXP z8q!pq1(W5wJ&Vk+k)HyzIXFux>xx7!$GE2?vn~&x2#!uuMaS!_+A_@37vmL$)YB{h z90d4eL6WBQZ2ENbQ5(tnRO~Sa4iWnHx}3_l19pRA=xDwX_mdB$4aC7(6{%HI8ll#} z_cbiPSo(XjrL(mZ8pP>S-$z-jpAOVKQvGeQadtdlb47oK?U6+2{n8aku0bDJfTK;@ zv^9hPFK9+1HcB|E2(F!L9bI_W8f?)r3@@ELrtFX~HZTs|FRg-*qO2Fb3~A)}S{5M* zy3B2gU4-4_(6U9TVFZ_?0)ter=Xif~Ubad+66prv!)3#v1OzfW0H+oaF|Ix9lsVBj zU&Tz!7~>y$f*S3m?x{{|@qKn-)~EVI{;SV_An;Vjc+- z^Y6Pny354ZzYTm1VK`%9$y9#Dwe}m7x+Ho~jWB@KZr7e3ZBbkkJUa~OS8c|VoVZlz zQuxlpuXu1l#q%3I7}J~cq@A6dP5n6ptZKItB=%I$^%Bv1I!5+P!vSLH)!1A2ow z>28~I)D51vA?2l_z6P7GW(vj+Fn^AhkoMH&g-Q#Vw!4HnTaP99``v?%8ov%?f8a1deez9NGYP{rzJfP zGF_}n7tkNK6*-tdd!LFC;-;eN08tDeq)X4^V4;^@r~C@y8^V$nJpxYVaxO0OG$qo0 zGxOk@*l)pT5}xRZCi+)av*|P9ryjKEK_KMUr(Y7z&ai_mKq_y1GP4)u)?fKIOeymc z;wP7(ty`eH?JZu$eH_YpTzUrn2VA6RcbuQu>)#E75v(kc=SZ&YM6}gafMlU5Oc_-2 zysBy9j^wH2t9{BMeDL{F=r{py{)>K7yhTVa5wH?Uq(%y~EkK-jXiq5dQ@4y{dROwY z?Irh1>pbOax$+D~`MV6YPi{s3){j)Z)LZ@?3C0B`^o*yHptpxQK}3ky$M-%5hJS02 zwT(wvOEDKtj=xxY=K|d1$NpK$Qc&?n9CMuUepjxgAOF1^Ar^O_T`E!VdNw@8utsrh zQkbJs$ln2AAWhv*L58}RT~(a>)&vp@B6_za!^LijaHa`MMKy2#mI#3@Df$H07)x+O zTAgxpAkq?m7=jt}^s9?2Xin!p5DpkHGDR8_ECqm>{(i;>9`ZJ!V?&`f*%15IF{7Ue zBrM-Y^hi7E$u02}mwpD|K1Fp=28h9#07&|DR5?Lx2c+tlIPRH+1;da#yIWl z^hyq!d*VHaP|WdAln22KKD}^;C2m=ECpP*!Bl3N^NIP=(-gmuGy*I^nAV#C@kSZCH z8{;@8*Tq5)1G&9Oz^c8fO`dMad^gRAs@H1Bw)N0CZTHaWiahSX+O|gb@s8XbfhCs; zV}EepD4u_v;C8%vS7MQx(KMs?r{=EL%?lQma|ZpT>U^Yp`Px?bPkg_XL4y2)*P-3m z^DrDNx}4`$6l$nXd$eg2@SChI;%|n-% zpfw+CmbY!bKO=S1Q*9D3itPIoiz24T;+1~Kf-c$2%#r$oFG5Fcc)a1Jwrp~D6U(yp z5s^5Rj5xrtg{qRQq5>?tA5=Jzfa(;oi}-q&-?)&__4`$#E-bycr);ai?K7^$9i42a zxLFs!lQ$m}J;d*cvGtK35hvR`pv@DWUxU?%ki~&bIhxzh z$uwDVHcPyvUvi5rnfNNuoNrV1KM_VdO0`%>*}oE6ge;E&dQI&V4v>m;6N>Ad_6n2b zvGwQ>6~%*1Fd=X;1F-aoVcgaW0rcjjK6QhZ&zWbLb-Cr~ZX(Qh-Q&)zj@G^|g&~@I zyZB-FRV!EtNC8Ok1|vkWrqXUkUS4E+3T@aSE}_%@Hugh%bG}N<`Kov3+scoI=)r5X zDjk&0u}HKnNYskRUnzPYrO!*8t^a{!7dRI^jD43|h{Ns}NG}EJWp8p+Vp@z9{Hu^R z0GrP_9}|fV6S_EEh97L5$L{1~_+$@eEVWgi7yr`yy}rwMIqvlJhtsT32!*zgozD0dm^neapX zWlM60k8|h;J&qMKNJxZ8-}sAVD9iCSgi#7BU`j999tIPMr$NV?UzNQhJDq^ zN<@L@_CP}nu4SWk3%*13inuS4X$ng?SDzB&Hz+o5argeGc%ZYAsN%kvP{Cemwr-dm zS;s7^nGgrKpq_V|3s!q?Eg^~hhiT9eKVXkLKn<@stZ1h;oadZj#);=kW9K5LRit{% zvnIIwW#8ZarjH39jfN+j-qr$P{E$3I?;Ga){^rewyHg|g!p~dOw7sL1sy|0}6r7k- zgzy;WmX+StDPScLL-%v)Q(2YwSx@FC`Aw=d(hY?~DtNm2L_3rO^^q?ZGj!NG6IN&EL!QZ=b3oYpTCYZGSXi(_?SWgsLA_rVr_ZymUJQbel#4y z{avK6?r^QD0HFfie;lK&JAWGFHJe+c(ZKqPb5;>HuPs_N(!_P~i^@bL+1b7n$C0~C z@M4*sWZlcF7-k#>z6pENzK+jX_=O??-GbQbAm8NbLDPMWe!@y7wPzU#4%cfB$A|q% zlHLj(p|eOF(9)e*eQCy$;+34_3yeAHhp+Ey)D}BrN_?zk+ThQ-eW|NU(g--O#{OL< z%j%-gk(fbVs3mZtRSC3~&=67hfkd^*DpY#5hSo-H$O>|s9pP|It~Lw&V^B{O4xXzg z-Qg(m#R+}cXznwoTaxsaAF4zzsrZY~k_dQ~&SMZfL&kz_@S28|xIgH`rVd-l-#BR1 zQb4;m<|uSp)s(#{N6jn=!+zpbm&l0+@Ga9*V}SiOgaWvVU`ZHowsnP99lzxeEus5d zs>I>)CI8;aaJSSUh3`&veQqQ;LnFiw%s>T`bmKE&g7k`kEeTG?vAb=j|((+=THBF(80QRCUBJVWWV2%Rr6dgR@bt5&@o#=)z;T8 zp=et;F1yE==MUU6H^&>#%D_Xt)o44Zh?SvWWirHl+B}X{7tshi*uEHsL}r2sk@OJj zQoGj=8tZm2-9LIM(g8E{mc0r^9?8&5E2DJzuONn=>sxzYa>9H*^>A6A;abQQ;$gGL znG!+UFXkK_APntP-yd5t{GGtm67o9LW|f`y4Qb3DVvB&{fKthCcOo3Z1uUG|RNFUp z9aC3!66vp;Daw!RTZXvLCvC3VAd!gHejg-bnQM)7CFk_>qyCH(PwD`PiW2R#!{c8? zlB?kGXlXfJ*QII?YQCuj=34XO^4kkg=Ke5TFd@rV> z6o=mff{fh%Pmq~L5}%;o62qkC55)3Ox!fn~HMuiN)*_P6t`It5z?jIR0x<g#__`GL~o(Y@cZs zWr$3t>B1B6yyuLM(rV;99j~6DsrcuQJl-KI`ETCxi^j!|`PE7+hM!%(=y6f#c}Q6E z)8;Pn38%QD0u(0I`xB!EhIbR}LRMp9gHQu;4#;a>P0wqgG-CxB$HxiLHNW|#tniK! z@2`JYay*YV8yR`5Vq3e^>VJu$&yoL+WV7@uhoOS^F0U-az`oJbs3+y5Rr+34suAe| z>HU=T{?RMItl7(o5BX&&;(*-zKM)Ddy#&7fBTIL>S?mS+^qWXhBrv7+Y6~hz0Nx%< zdIYq1ETNigP}1Sli?PgpS=aNRfI84Vl|L;H^P#8+_tmd}rz;06bP!5NT=(sfLitk{ z^|Z@O;CJB*(2G)E>a_EoJFMBt-Lv6Nrhp=ApY}gTeFm3_9Fdt_UMLuYyUbU7;6)&g}Otf=n$@c1Ji{kAAI;_F@%X=d^vFn zBstgjl?OfWqhbGa5D9)mhqX~H$nE4y2F$*v!jE9+xP!Bgs@S|eZN$G^UO z30>2f_1dHWTSD*yJ}-6&1BWZ2r+Br=mpQ3vp_rp9>|8aoz9CA|ep;sIdbhhdNn-x@ z+<`hVmtBTSw2JAvI-&0QbwDA)s1S!O*}J{*Up2RZwaDb0JyPVW9mEboj(uRMKm$vg zs6=7MUj1RZsrcw>LyJL?o$A&rrg64$)np358eJ@3Kvhq-cdh_9)Fz%#BAy*Hm>=8I zfGo^#u%#@6`76eslwv2qww&-YX4iq|!eczW<9Dk~v2j8q=@T|;wHuS^sT5+(gK5_# z38Ru?nfnRb?p)eW5hTD!m9Y!B#@70~D?eCBE{!9ik4$U>09-2LmxuzU>1YqqC{BaosZ#U5 zrW7r?&waE-htVPxpKuWxaEl^h(qWrVEbe$`v~#moCBdU>kyODLw)_0Kg&)fN^DVM?x4asijFics0l!YKFM> zPIsz@TGo}zqv>EiD-NX?X0VZpJOThWy&r+a;%ALveI=v~BCz;}XFU6yRjDb4sQ>sE z_l~cjozMN;o2i#bVFqNc;e q|?%}&;-SN*uXw4&AretLW z4kdZrPlDeVl#?RwHmX`p38zOxu&<$W6qVjqs;v{p0P9|YuVTHSaKmOgyJOVJC*42J zO(1nCH0Ey#Yr|p6Mj!QzOELx8)8Cte&|&3Axc0e~#k~O|o4CRUx1Qa13eZH8_eVp` z{mv12Vu)4flfZU&KjfuSmHo=z*y>T6XUuC~0}4^b@Q+&Rr>7s@R<`qMnF&QU)H11v zH5VEg1is`W&ygY^koe?V_5o*1Jynq}M#r(aP{D3>hHE?z17Qtk*=ee0i41z>gJl9| zE}*VAoV49OWA+~p4;EukRA;inrdUuS3O4c`C4HcPWukC29tDENX~{ZxcJf!QZbuP; z0P4TcPAszj$(QT&Ya6rh$F~X>q`#UKmM#R*tO)DGKvP8pu*?W4Do8IS#C+03A1*u~ z=|&K{`AUKkFCp_OA9LVc+hBv6gZu}MzSk?CKS?ZQSnGu`6=~$NqTB+k5kTm{`Zf(d z4=BPY5uf8LlO>uVagX*qjClHJG+ZjRk1@&I-h7%`n7KK7xzP_trgaIgiatYa{T&SU zG8w$$B@Gk}>)KNdYcmuoeVvheRM2>&!_ZDaf6G$v$o3g))kk{2e@#&GG0ngd@egS% zXLJ>FS^4~Jwzb}est8sH@)a$xQECwGe>7Px+H&rd#Om9?faS+=J74>E3QgFShY1X0 z%ot$P8GKCUX8r-S^x|Q+DA(D%Qte--LKF?Is78ia;K^w>*ZcB=)EFw9u7`!f1*KQF zU-;U2%j&hhPaJj;ec~uZ$W+P|1bl{_H!)Zm1UgO0yh)LrVW)y@(mm>o?BD0{2vu^b z(3f)P=`p4lFlLN0+~E^!zf26__p#Jdrm|(;T`OT5(@5<&U{h#jQ#gL`P~c#=! z7JQQ8;_f4&dfjyLO05v{MOlHCGuFcwHGRlGMC`XaYHxcsjT_j1tNjrZpkItn5dbKzhc z$I9gQoaHzuw>-;eZ6!QRN0yhd2g|gN%01`i+;1MGi`G=mU^{BwLwf+Pu_`QTRlNIq z(naNCtu_NJ_rmf}I~(Gn4g_g=R*G!wz0pIUmDXc_9irU{I&(45;1pGmIc>NX7$-nA z5OA^=c3!6#tn}=#?xM46rFjmi2qd&n;V%rY@`cQQINNWFd(GAFo;U}qZP+Q_bKf?s zJ?2XZKQU5SI(S+VX!EM~x9NO6et*$dVW=fHn82AxsFHTl7~<;TRB&pt*qNrE@>zus zag)+DxpJ}9jo{G<8!CAe^6?+EDMc|E^M7)40Ysq*>}ivZ*tx{nZl<|7Cq|m$rHc!I zeZW(C<_p78LvWez#2V4qub${_EjV#eS~l~JQN_pfr0BSiGo;{%Oa_2G{{uaLIAOx@ zGHI0jk(~C7JAmPB83XNli+Mt4{j}ocWn4dw%<}?uZQJ_0wub#Y?ZxT<7E!0g=E|PG&Kas8HdtCp^+$B(k^3K-G1rh$(`pa~jwNve}{E-`x0HzF@ z%6kr1Z{;fUkx|Z_lH*+blL0nb-f#DhE{)U)SDyIet5V0|aWk!=`X9)+0}F*?yK@*| z{)mgfKW!zjE>$fZ5w?E5v9CMtJv^Q`)fLg(lU(>%d6MK$+t)?xa_ebHHq$+EU3{N= zE)r%RUxJO_LiLqom1gtFt{>>BP;i=|RvdhI5ZQh6R_vCX@B$NtXhZkv*UPWJ*F zexau*q}o;@F_M2yh&+fY;9aXK=CKZaad&;n%FGJGn==9;T8qXtHMkVH80Njf)g}sO z?$Oi+To)OK1kAfpD6~1xZcOBd*IrVbsp_1c(ua_z0K+0t#UE}$4BYEn5gfDIyg+mKY1z*{e(`>|Z1juKQdnRDSaBs2ERu zY6ov+>*FkRZ&!fLvc+l1e5?l_mLtFf2z0p=gzdrjTLrq5gJt=0Bj}-3OKI;-UnzKbj}4C0{LdMQ?0x^Y*{;caQU zHxT?zUdkFX!j0EGbfovTLW*nFd@(>6QmDXjKijjRO%en}ic+_N2z~|fU3ygQ2O-iM z!HWNZ%%)>(;c{CUxm@epnP28KkJ_SGa^b-hn=sZ}+L{Z+8-u8#5N&$bRTc~jcNFP8 zu&&1Fd@9WBL-B_8)6mb4t=c_$JpKm~?YVR!%_>L>u89W(}|$Y*2{MSnvZvy|Tg zmtO7agvJ!Z5YeB`{yuO@nmE3fY9~~PVBa42`YlC{Aq!@*+I^eXgizf+l6r#>*TLQS zY`k<=&+zhSv5MCMyzYe&e%g)02!6gFH-?zvUF z%egfvo8Ij{qlVWGbg{s>&i1RTgcko)DOxRXrvo;SJDtZOTlBBnnrKF?#?N)zIwoxr z9E^NTx5x6T_Pf{+HyvsoE6hJQx^f-%^g^GpxTu&&s)&;;rEauO|MS1~okNKRsA|I* ze;%$7Dq``MyWI zz0S_m4PT;Wh0wHx-e{r)V@ZRKVlv!+N7zb_GJ?Ky43i@(V}B!_V0J?b49-!Dpmc35 z1&ZHdalr)tfebp?gJ!@OHM`D;mUnRkta^-7b2(IwTeOPE?BnkT8v4%XRvK}wo%5be zR@+(=Am%mDa1viap!j0p+M4Q1cSaxjcWntO1W-f3q-GNK5R%b^h%DhBY3*%WopPNz zrFwu*dS*#=7aV8Is9IO97(Hu`Zf?^VImTvJ3GC$1k!_&)W4}A1s4AQ2_LPfaAFfcs z>T{;CU0nH<4e^VZxzW3;u9OWPBux0C84s8HBm1s(WR!unD!-ogjJ6nQGD{E3D4`zn ziB*}?R=Q0`OE~1fUa{~zcZ(bQur(UKh9-pYwtXeLjhd|rZgve(0Jhaj#{YpH8Gx3} zI-i8ip2d;mf3^>XnYWbT?cdKiYs~XF=`1B8_%k(v$IPAN*~xD5NG1~wOTv}r-B{5;XH$W$io>AfqxU&`HOKeXrs_(2elm6JR3CrH$NZtbA!cRd~#WT@Z5<2-5^;( zjMm9=AegH>_Ql$R`foHp{Mc0A4jB@!>8;Q=T*(_6jXU zJPt*At;ruX7i0Yru5v4>iX!&4jFPCRDyhQ}eel0~VvQWAImTKsx58q_a)v$Vu$9_> zSX}$WHGP3i)pC3CIaZ_9Y=~HH^2$kVEE$hqL z0|6BN4}?2eD0r)kjE1?nX_+iW0<-9t#>NoGB0;#pJm2KnGz)g+$8lV@ci64zwN?7K zK0R5daBRy;K3z=;JE$SO20~8I9XeS4+3O9wcA#s;r$`}dLB|EO_EF+6nx%BjllEe# z%b_|kGLLUQhz4;X3tvp7H^`;c2)U##rQjQ(rIm>b5}J7bk4JG**zGv2orJ+TOrW`H4L;F3{^Y zP#jX4OrN#(gTBAQ(Sb%)E%n4CdWZE5>!&+W2$IL>PDd4ZG_TPm4qR;imY4a-giich?{Qmg{GwbIy7CNd^k>1<83;ixk0DlI94vBfk2XHaE zZrawRzA^kj&8S{NArI~wn<&K~$|I*DdZ5Mg`4-V*dbs5tM=J4|-HDwTsHd2Vlssmd zB%AhW9)CR>+tj2$JG}QRh{epy(`s*PA}Cj~)T^4v z(tTQ1a1NxY8%c6^8}Sq-0kq~}I}ff?E}oC=Y=~Z85Fa`-E|R%+Nuw>37P*WatAElS zBW_Qt6h5_%hXCZ4xD7#ASHU)0LpA>5rdqST?>VQ#?$t+S-ecjhV_|O1+~Q!L-d^g= z{{KMnD9WdjE(%1^R_z&ExsNuvCmxaG&p`b_ZH_vv^u1&z2wSm&MESFtsrT&9*J_t{ zTc)sw#vg?-qsz}qCRwL!WwQ}0CALay4VvnysVb_aeZF&m2;t8KCF;>zJratFvj3$uLCS=RrBg+7>H^*6-;aVOj0t{=Z8(5o&z1UaEY{ z8|Iji=}!>D)$mC1GMk~nmo95j-s-X_Ln%E)A*@ixm4WZ;usiA{r1yOn;mIcEe;|j= zXerWpDRLu(s!|KQ>yP)gk(eWC)z?%7N7+5DlE%2g@+S6hbDQDJu)drZw*rQ-SW0MT zwu@RRnT)6DMN;yUHI{5_dO(xx{9`vB(sTgcOr$mbN7TKy+w$yCbl%C+*59r{vqVwz zXe`?)H7l|Ed{lg85*3A-$cJgDM&y$fHNfr$Lmjs}-9D5%SIQ#hY7G$bPQOK=^eB_t zNX7A|YGCc3`vtLh?hLH5LDgIijcl=#Wj26>*UpJ^#kE>k@3|7BuJWFu7g_aJ;&Ua! zY1i^4NshQ{ojRGrhHJ-S99FKQ1g7Ke(;<;IW7)UR2Q>V-%VlSRQ-!;Y> zlT!IdZNF`2dGlN2pufFl;_2YzmHpgaRpYs99vuu-x$P)>@M)b=y?Q*$pbzGhlL=j$-(Xz2$tOS2>0(-RF3!99^5a*Y?Q^{! z*(_LjkF69)cP;I+G0x(8s(^te5Y?RoP3TgS*wTlztbbP}7S@fi5vdbbZ`l?4bN1jl zR9CB|v)!^_vq{=p_NVY}bJa)_n3@TI1HnV8R{OLY!nshnrsXzwK^7V63b&713F6i* zAq1?wj%-W=DbC+>pf;*MAeKDK47b5>o#Dg?%DWhd6#4W(FLfcGQ2PhB!RQZRyE~ncC|t)!!+_85;WNAn)bK;}938AS)RTitc8o*Zt=YR- zz}97iMvv;Zl*x({sWxw@eeU*%dbe%zoG-VD!Mj@8Ag&K7X*?I*zN$v8nnqcwfz)>> zUpNODP&8bG*7L*Bc}-Jez??;iXarS_JCjFU#=LjNqD>ZStN$*=wwX4w+lD8>m=}0O z1y43#by!y)d!uVDaQ~X+6p~a_+B}VwYt67ZHS+}x@J1VHvR7QntF@E%^-W_==iBK^fr4vmRZEyU2Y5%gnUek}vuZ{Y+ z&-*y|S(5MdASb8!r*=&`))7l(s$h8sz(gC)NywS5s!?abA z<~%Q6LCbd;EHA|pT2Z;JHQIC9;6$iu>pEPJ-g>C{Tzqq>oDH@=aoe@PmeM33T2KDHJr zD%7(^=LyPgER zPF$F!QBH*4CM(nF+(V9*;3&_DW6JQFTuY`iWGn{n(5w6QM16-{jU6X0UgvC zoA+?qTn6|CQ9ROq$D{jT%IMxKEWOu_AEf^7Dn(Ah-%s0LBFO2Sx7fo;8gO(n?8u#Y za~F6@CBrrtLpJy|?YxsLcMmFLgARL+-2%gXySm{d!)^1551Fxz6bh=}afg(533V}h zsIm|Vekllo!zy8XGrLcUMSI;UrT_RP)F!+EVop*boC+`x9A}i4+KP3CwY@8M#9PRn zw-P;b`l=Hn?H3`7@y_*coX%(-1$;Qx+p50?~=ROguez7m>7hQnX0+znj2M7G3~@ zY2lIk=4XHA{9l80m?_H67Sb-?Vs^uOrd&jYW-G{oiRg`fW>B0AmzLgN7}EBG)70ROy-sX;Yly?zi-3yiY{wzmv>)ODLOEEGF(6uPr{vTYBoxzm-^=W=J(2~ z&2Y)N-)6bHtcDenibwbCe)S`863DU)A!o?lV{v_qQ&|7jL+Uu7(cA25jIVOa3k83% zd!ryRy26oWrRBgD#Lu^sFb~DpP4Mq3s(w3v(A{jnzq5T#N|fA}w2bDWKn+Kpbiv*1 zwZZrorf`)K%)xwpMR~ac)nP9GqV#dVbOMPj1a1n=32K`9iaaj}SL#<}HSH{Q7m;Hib54z8} zr7Y#?EO2nJw4ESl21KQ-ck*LrdMi9^Ox@W3`PEfzQ+|(W%qohGXLr186U_0*R?ier zgnJuy|A4^Mz{_uWLW1~Wcc#Ay&PVoj(I-V$^#l!AKO z+*cdJZ+p2uX^CH8?^|q0akk%_G=fpysx?PIm7H{3-NrAyqLn>;kh1NPsof_>-=+SK zdHeYnH z#0#8_P_Igw5c4xK5k(7VzNa)g=g{&DCC!9{dNwmC$mv ztn4o|GNwI!{4unf6!JFJiL@8&onuXYW-QBq^|CJ|{dDGG@eQd49*d*e4Z)eh<0qf? z+i>{caJ0i&Kor;A`=7c-N?}c5(0$jD4HxQ8*XEz3d;fv(yDctAkslU{+UGnzD-C!H zs3;%VI}A~;)_D{CB|7fEuB`)`)h@el2@j|~Uk&&KS1I6$s@e8Znk2R$GB~>8c1fu4 zR=T_-PbxM~5PA5Blt+-lo*>&)u|q-Rv%uFKi|vW)6Bgj|g5#w*@}`FNO#bU;U19D_ z2d|bZxFfqxN*G}cXmkUJkYknwv%g6Gs@VPE{l>{&^q*gg$g$#H5BuHn0P;|;NC&P} zA#@-}YSC%g!1a&0w77b5;ymBF30UDyFb z-moNE(Fw4N_}(1eU;RpzejZ|~x_l0G4V9Tc>-{>$bz`gYTtp;r<0(+G@;hLA-pZE( zg$elB59^=jvwr=X@-|&2;r+-9S@+b?DJ*LHLp3oDhBl}-M>_);Rbep1{jrMw#jZKU z)Bh4RC%R@8UCq4ZVLx}%s{%H$PFEi_5ij_7f>B&@75$W3`X+nxUkG*choG>z(9eL;QITA4ED_p5ESOdQ4yBk5U=Z>{`P`GPByhzh$}nsK}$m*CuSmTi6w%qMh92*?tmaDry?aXn0nZE2WN zxO6mV?en%)@6}g+n0-AR4&XVE?UnVWCl?{TC5w4;w*it>=#I&0`wkurwe0@tr_3R#7j^#5l7+%wI=t1S zi-~DC>gt|#ghGb(A2fg zQq-3x5YrNruCM6HF`LQJzTaUv$?fPlCw~s%H@z~ZfXNIhJCRXHND**MlvP~M+deX> zG?L1OmXH#xHwqqTt>CO5-?Tr35#rDtB+Ls_I2Ij?FG!N;&UMHY>7TS|&1B4dX-wEz zGz#|~`ab|eLA<`|8CwJbG6*;$KDA?chYG}zfs>Ba@$!VBCmAhQL;8A+T3Cuvt4c~y zl1()3ntG<1(WbRbl0Bh(q0|iWDcEcS)9Fwlib5q4k^^y%T9BQ?DJC)30;bHd;PnMg zGt^R(!EECnr8QgS%M2fS5Q15G;k(iS-Fjm<>xy2z2TEup4cre-J5qu(jN=}(0H>!X z0Ck{y4kP{L=hZ{{WJs;#~g#(0=p(07tzPd{_SflBMEb`w!lK==Y+4KV`q* zpSt|k{{XOmgkw?{W&WsviS7NPc5AfOJWVjSWRz|LA91-p#8-{}!X6|@KVjbvBT0c4 zzp6%XdVi-~9&ufMr*vm;wKFqh92Mz|=O({H$?#Eh6KL?URhp?|Uh~C52J-Hll_VX~ z4E|@+j`f|Xc!y59{p;<({_5u(cKs_N?kkJq=bjXd?NN-MPkx4~Ty0n z72AyYXPkA%KhnAlD%M>_AQNB~yB z01SKQHRac+md8vL)!A;=H?y};&}5Pa&ClVQ{6hZ#gS~zeT}!Wg8~valw?0hTey8Ev z*#=G?ITGp+Pq>E`5BlUM3N6nDk>*)`Hnzx944Svgh zAo!JS;dtTHWYwUS+4U)7_~ zUrqko`d`Hl5cn=H6nIy~ma_OG?AEV+XKgy>XEYCY9Ej?N3+zt9~}wc+S_tnr@Hbo0Yy)dux`kxIk5uwom#&=mF{J zU2lNCH~d`Iekti54e>sep#K18YAtXixv;jFNH#+wfD1}MIYFF*k(_$hhI}pYzlARB zZ0)>Raph{dOiY&dlC~09zlS*EAP_KpDs3m?liT>WR@AgF6pKlt23XJHDW@peuBz*E)#5LsuZ7u<4r4I6W>-R% zCZ#Hor8cE#RFoQo-J5>wntCU;%&Fqv1AJKcdGU8$@wb8e4XtT5+IF{jHl2NEqRO`L z+pMx+f)iki%N_ngha8S8p8o)Xv;M^&8$WNKfd2pqb^ic{J`Fx1*R*{)Ym4HVou!dq zwn1plN&R# zaPD656s2BxOnjWD^aSPIb7_YE0pik-u(H!{{Vs)f5A|ooX2VVH3`!h5)T>ZtN38v)mT5^ zi@)$xog2jVnoXzd!+jm>_OsmE&ku?;qY5NV5;%F*?8O_CGHs36Jb*}1SVbXSPaA&K zf4298JWGG!eNXnX_?WWk_maUj-ydo#jje!l=kcDr_CA&DJ{JD~g3Nx=z7+k6^`F_o z*GIh8JYjWZeP^cpOx1i@c-n=WH!e#^N!=`$Q3+)b$g?1e7XJVvc|=L{_$vGri)~bO zpsjSf#w~o(Q(RACJi@EacCDW_JHJ$$UZap zgYh56zZkwGUi?}3v*Qb$V_wzn55hG9)5Cv6GpbW`ioW}40 z%!~>ZVTH~(8SnlgppM$&85GQn0dcV74x_$%ar};Z*NvHHnXYYB%N1HuP2Sghsrg3s zU9>*`0E7HG;En>q!j?&ar$$kNgQ%e&YdGlt0C`f}^Sf7k&eye)vPxAUTgh2VD$1t} zp^qTv1Gw+}=@uK9oGf7A$Ij4j!-JEKN%zU?(w}Ov&vfu$1!g6J0GZE2$lwkS_)}IX zmPRHZgCn>E;16Ei`&W}C*!yIzS$B#yR!fx`AOvB!=Z*;Ap5K{0DIV;HStmzLyvKlZ zCP#n%x8!QFG989UPV9W4#t0Z0#xutQ@u^EJlBQloLKH|@i6fKGBigWPF=#eDBla}? zpk)1>KWJ|c{7JpJl07iq3mtwZ#s!SuB%e~b5)XLAe$(lIv_={5WwciC8 zwE5J&>1MMdHu5*76c~y1kDj%|kEP2f)sI~t=yp+(sZGW1wmy~6?xQVnYEJa?+inI) z>7S=3^!nD5z>W#t%oWtH1UFw_&arg+OO)5=KWOsCNs=@CV>kx@{zg*@o__j!SIb|vPsHopD^9TZd8S;vZG0Mg z)kfvwQ-wqIY@B;~abH>K)A)bD+8vg%Z04Ho^2c$L%=^a*#2lO`$pe5+2Y&wm1%BA~ z(~peT5M0Ipj>g>El_!KN$o4z)Us=U?RhHy<47R06%B(%4=A(DYs;W|gZKbPQ+Wx2H zo&;kh#y%$J{7pvZw^(VJU{{RIV@ul{M@Z&=8Hl1+s$E{z>adOF7Kw%b8Vg)!*|!QoaRzXz(A!y$EZ6 z4E$B+{gyRpEkf+NwQ-DZU%Z1LVSy{h2EL#Dp8o)8yNz1%9|QbKz725XErzcnVP?kx zB|QV4Fx?JtM|%7-_{`z<4EIH=}y>jeA(F=pPvj7NEc>Ft#wdTRLe8x8iR|AQ3 z=C;`A{u@s3;eQg#AOS^)KQABh&RqWhO5~y@EjoLdpbMXz7CdAQr>_~V%i-;^#o=!e zM}hNL4aIPC@=08M+$;T_+GWauti}M$e()gi?oNMN`+VKshjN$w%D?)0XW~2tlJP@} z-~B(q{{ZR@5=ETcK&}S$Xot_YIXKQmesTW*!B4&)$>9G0+7jQ!(5!IV_^U-rhF>cf zxQK{fob(_@$Nl>K2)4MI*Hn#C*q_{Oa858o5%~2L@=yE~*GGvw8{n;A!x~Q(t9NW= zQ-T$uM<0gd72;tkN;rhOpRQMp;+f)~u($2i#}=hE9Y*`X7y6jE@jr#_S|Vo# zCX3I9&Qr2g%Dxe!^L&`&cM1Vdt9}dVAMi|H_$X(^zZH0Q!e0PT7+Hs zW|W#wen%wveHNGPcF)83>%t7bh_QH_=Ox8e%kw&v)IGHdlcgEByTR9mHy>+ng}Ghw z**mDk8y^~e4187npFC|R!>^0p5}QWW65czdwYiZcSz&N?$d3~k1Je|Osz&+b-W z8S(>;x%yW;F!*OLXqM;c>TQE8R#~? zF7U3WZ+oZSLL|1kfd1}Lj-#JkFG|2WpoVme-onTWYr! zrICzn8<3eCg)X2UUIFW0C4Rvlw#KXQ<5n6@yDL76c`~l8roeI_C*>D$1H$u->Nq2| zaBK7<#U32+r^Rmt&!B4dao=frwYr?C$dL{i9TUs-%S@J)xq-vsHN4AZ1$hfT6rAd$E&v~}ZvdhyU_71NJ+Sv4vu$37M~ zi6lBv5#P&+X~s8#~{KJ{0hcRJ1BsW+80AR6V-+taP zx2AsS0G#&^TIJLpCZme%PCffLS#>{BKVr{}68r}JsD2*&RF?2t*;s3dr0e+LHkYUc zw3V3lSfy)gliU3tt$v?;XukNx@xS8+zxzMxc#UHb0hxES|>UTF+7ycFy&2uD&%qZc4jPuBg03Y^2?O(2c z1^jjKf5(5b?uq+w_-Dsf5npPa7PGMNQZ>4?<50JfI}5Z_baX0iZMWwtKsW=8x6^oi zT9V9a&~k^hNvdvY>)LZimF(>Ivp=4`kTmhr;CeGr#7hSX+}PORrlC>b)f(Q_Yd9%F z{G9RTvsj=x!k=pwKiQgWHiF}3-!sdNQ+C{)x!as%Q(lZ>7LB-X2p${c4ai~-)c z@e!%@iTg+IUEkLFpS0x90W6;lEE2UBP8^F{sB*4U?{|A!cHeePjxaW0mV7Y-p{EpQ zkP#tKj@hW;A~UuzyK@5G#Pz465X693=O_J|=XXr>NsyI5Hy=TfO${%Y95CCTx+y~f z0X_NTbf%yMjGMZD;{50tB%jFNj({%q$U3}b2Ln4~lBa56Xq zQU*I+l}I6ansyod)%;ff0FtBPU;7Wng{g}gHI z{>6SDxLI6@JW;C1RB?~=+o6i}7t@nYP|}gL)9xMyI`Qr8?Oea?2d7VU`wjSMUm;st z)e7YH`+e8{0A9To@4=E^$(f{LndAT~^qjX4o*pVs`q1z(Fua>c>Qx6m$AlU1#BYeL6LU2t)2C8DDPIuVdkfCK)k^&H&2reYmdK<6#|4 zRTku%JnQ}tjZ;yK?fF%RB$mPY(QV*o(QXaJ#lQzF6;CCPt$K~#gB`R1Rd823!j6AG z^U}HfUq_ncl!`X<3BWkzds2pBDvH<8B}UCQa#K%v6Dt;SGtiEID%0@VYNhwwIU(`2 zcJ|5d_}59Id^`I_*{2N4#==M5H@KQuTNIGB;1~&BR{Hq$+>Iyujv86UpvoGZW z;Hz!R6kPF&xFJbXm(Ub*n&$C}T8JWXcS z`i_aJPho!pj2PyIMvd1V%%oO2g~25D6eE!S*xud`^s5-p+e zW|qpBR+8f>b~D`lt|{B->+4?t*T`e(Dp82D>MRA%F>AIEflc?M3 zx^2V5aceTmZz>;|9Y`v}+z<_Zzax{=^MNJ;aEvP?VW->Kqb1 znQQUlzRLdqf=hnU5cqG!nxE|P;+xi69e6$Ngv{JSdXfI}Fb^NRjRF4vdKCk$eBTq` z7B4S#64o00y8i%~6(h|9?w{QGO8W~bQU~M5;yo+fe`ZU+hh9JUX{!Ft{w>t+Ev~h# zU+s3n)+{u!qD3YZ%DSmjY)RYzA%@aPA$%$A7)BNrE=Mbo$9_BW?_D3kPaoTSGWe^c z_{&nTlI9ynad8wchh=!)B3C>T1CkWuC2^JNK0|{ux>+v~(w%h|MsL~G=1s3i!b(E6+a6w#-z42S$w6}yl6!@Rv1&7B^hyEYDv7cV?6yoe&PR!FY8-pXV7jUP{gp~+D z!<>$re4pXgxvBV3KeVp2Pqk@Uf|Q=#KtjkAk-tl4`$B^HBO}|rBV`& zfUbdX zFa{SH$JV`eYSq@)c7LCEk%7i=R~z8C=Pbd_ zwIvDB#;Q`N=IJ*NWmDLtDAZIxdns8)R%y9eCm+ji+rQy-v3QEt;(mc-%NC+MjYW(+ zEa(Qu?dV6~kLB}0>bg$G;+SlMGheo^NpMlM(OL1-F zD%#!68u^hk#FIqF92Qc_3vdnz+MwWmW$M}%v*BM5-*{hL)7mSW8>fO;%5WGH->;?* zp&qq{T{_rlRls5us?d{!mG7)wp1o|-S3hjw{ND$jhm)U9K9vPS7FZRN%fF1c0Qc+VIljA!ZXikdVn!rQB09I7emI`hzaeiaU#95Y-O z*yv($wMITbMhV7!dH(<;RjC;+cf%a6fXD$ZPIJ&?@y}fTHR?TC?JE`UNQxr|pOtcV z^@d7B>;w-W!QQ0Ofvh%YZms_Rkm`cIVO8QpVRz^GZm3fcOBZ z9P#bjH2_yxM9TpzWhI<~6a&*B@qy@lYRsyU&mKz%@XaKEXJ&hk@I5i@O}nt%iXy85 zx=eSOLo6m=nFt*6cp&2#6xn>onIXCvjhA?Eljb~HUB#TMbBwZ(5ha`avtg5nL9379fv*@ReYs4nGjM@jYepmuz zT%3|lLW7RE@AazI6GR%HQ*0a1$_l6W2Sk?K8Ytq*BU!d*@2Ny2wOe+2eZ-n0=; zGS=IRBLWy6hu`b(Rkf3;TWU6T(5~M%D9#fb{$2?^vBCVS&Hn&ppBhi`7xsMkdE<-A zhc_CAg%meDhKji4l#FpPL_dr+-6V7k=Gd9!pOg>3$(-&fEJ1jT>(p<8}?t zuq*kB`#0e)iM5-(W&AaFedgL*ShGN+0#^=zeBE(|>C@Au2Z(;x-v+f`7JNT@X!NH| zPe_;T&_@`_l`)AFV30vxfOlV7`%W*z(#lWNnmYW8l= zPTz@rAG54q2k}-rfOC1#RWOa#r6^QxRjFd@Vk3_vvg}Q&muK;S_4|Pj~ zfp6`zZz~=x+_4no)y-772-?p2>6o0?JxBEhJxw`AeD7E;~hcj z2;hT}#c=TcDiv^^%KFFhKcTp{13DQ00B`nYxAQ0vtl47N9RLF~%mTUga!h#>NK z#yU_Lf%!l@at}DKe&@-G%64xc4r{sno<1*Md>8nor+C5{1TsM|Z0305)xNyS*9Fu_Fl@%$*r8fy|&1hDh9<*vf*{5^$m%%9#becKQ&e>wXBy-u2k@RZ&qe{K7@h*)Ah%5>SFKr{55XU&&Hw@#cAE6cbsrxSBhZJuiY_n|243BW5C)&Q}{f{-voqORlR(C30w5ztYh!*4~11TNI z8;Jh^fSUZ`eAIDHKB)Z@h-m#|FQz*rw(hCb| zAd$||oGHdWpH97N?RbWrd_F;x#;jalv#9LTPR^{Fc6}OYuIZ<-{JY_%I|Z3|kHz)y zm19yLSi;AeGE!X7!P11)(`xaYWf?8wB;xGTMoG9})|V{Hxs`-vPH?TXfKPr$(!NRm z0D{44ai@Yj6{IUDGU}I-0LKA}#Gnjz`{$p}uOD5>r-6qhA{PXp%3kP}PRgc-R5?r=EPPRt- zW0jS1g-A?-%zw^5Pg=hVOCu^mv6WNv?U0NDalpn8IOo&x##Cy;)r!oc%>ZoxWRFfc z?0D~1e8rWVU4HB^`B#rjU=Q=x@vm`e?s@Gr&vyNY{{U^D9sb6j5hm0$`)g~z4{Ns; zQTVpPO+U_-OL(QplHG)nmAFXcxCj~ivSj(x;aPp>{t6NMBz!&i`{2tD+k^Hmvhg2_ zKjAjT;cL|_dSOL zo^$zr74#qMTmJwB82H8eDfo}W{u9@HOQ~w#4sG=rrPX{d3~)S(lOa|x$t-2#+UL&@ zqoPS7uq8~c=j(Dz-wP}w%JU}O2JaaqcTP?GOI7acCnaayEz|rn#6AJfFuDAA+k;~j znbS>Hrzvv78AU_*bF?LTsluFTNw_4Hr6umlKQi3tIyZ?Rn#WqO4-%)!A0Xs#TomLS z_w+sKJT>623wUbO+iA8aSgz)pV64rIf=OYJJqI0Y@2~hM_v~$<_`|^(AML;UFMKN2 zMv>yp7Jn0X^L_Hl)}rLYbsgJ4(z}LXb1|4gUMSKX*o{Lwd@idKtlOj^5jU`W;2wYh zo_hLYfzPKVYnNj3Y*WM3cUm;n?IhLHSKrUGy6An^0P)^Y;ujA?G{CM$c|xT*IJ%T- zLQXP;MHcMtuN&K%4l=S*L^rx#jei?OcVQIDu`1)q4dt$W?;zlEI(pYF;@^UL4yr!i z;do7i!8wFBVk|oKCuq)2LC3ZS9fz7;D8nvUVnD-;OKjLX&ED30sfbLW-Km_i_ zPd#&w&bTRJD${(?XWB_GvUr}4aj$$j@uYKJ+nDZcoFB8vk|aaa40Q}TH@G$WjsE}y z=J>=sY4Fxx0$*wpO!_E}T{Wa9CeBV(efAJXvXvF^XT%Q%+4!c#l4*8aTSbd|fQoU( za7J=-+;+}+;=Sko34g9?nlHrNFG?2PNxIT+Ebk*_IsMr(DfAff!1o5DH08wL{ijrQ z;Oe+mOTLHbt+bG|s!0XLNnyY_&piJCoYTV~K@yw=Y%e%HdVeaXX*`}>Zi)s@+_35o zUPoTPrBv}RiFAJr=y#gNy?U0mmXSQyQ_8?BN}!V2{)4Xry?o?u^hdiW*cw?Ll$zKg8EwD2hgo2us>{e$me`#hxTT?J@a|IuZv#TuQ!{zIty$@xC6_HP><6m>TB+w_$MF5Xz#yee~aG&WW4hRjjCxrAkx!6cGmu6 zid7!M=FV0){wcld;yXKg&06l?Pt&dBmU*2biX;V^Rwn^fM@;do?*-|ZTda33#vfws}43;})2RIpYcb9jA(#A5UBt*>^K z7No4@RFOsVUT>IH~f=QGU)Y z-1ArEyj7c$PxP-Sa-q9@ojZPY)Mee20y3l?{{SzwV|XJ+9vAS2 zm*H`7=Kla@Y0z35hL3NWBa9~fxj4=d2*=Af3Y=l~I~B zUFElS!Kq>gf@Hxs&wt9f9D1JG)fN{8j^WO7fS$CcYBmITQ}?;3mn|Sg1io`V&eLbZpM1#*q@o=$pFfaXBT-zXgiLr4_Jn;7Fgz@;pM<=|(WgVvx!z&QhD zDCgy6!1T>DjHf5J88jl{MjMZAl^{5OwTzlv`+aG+%Y*&eK5m5Z$e;%{2hEJqHcu)u zNAUBUj`VGP-+O>40XJc>z%(-tyR;H|gGqo??itD8VwLzHgy-IX5UAWfQNZ=2Aqxc^ zdH@GXa}Y)WpeW=qBv1rR<|+N$c?YTdDUy)n#?Oo>+*i{kq-1aJlh+H9DS}1{8|5cI z!g|%A*f+DN$Rv*9lss-magoq}pQT46%)h%Mfy)!=O~NCBNg#UX^Q+}SrLH-qP=03_ z2Y@Jw;Y%>s0SZeJdea_7enJ9q#Q~#7ZSA$T9!c#^+5>3*Y`!Z00Lf7CCm-`4y#D~v z??orYMhD`5i3rd9hp#(-(|4k<5&IYa00i+U{{XP>g|@FaV)7gg&PcSV;ah&awqY;l_VoDDYc$^QUd z4;r)m;iMIefQSGfk&rm&Khycvt-CA{hH{|rS7|kwZY~@tLQeMkq-VD#w4o9N5*YEH z?=BDL+qF8oD}r`g43%ZLe~~~1Ps&GXx2I@m>tAfoo?pXVg!Nxu6)w~>%69Au^3)ZYd zW&(g$KAk}PE6S-^w{~kwDHjO6iud~k0lN^^v((Y06bQ!+D23zxgmHx zg+Xx<00O*p7(MGXS;>T9k~$qf9#v` zkuwwMmn(0_Z#j*BHX(j(tZ~mKzq!xaOW<#V{{U%EfgUjU1Mv>jOQQJ4!=|lN!m^)6GaqPl5Q1$!u}=Jd>!IT4;FZK`K_JXH{md28Q^)y zHv76pJJ;w9(n7FABarNu3Ub`wo_Xwg*W-EiTP(XDC@+QTy+7b=aaM~`Ngn&)KL=<( zv={8PrhFjq@5DO`ZF@(&y1kOpVVN5ePqT95GRA(<9%WGA9E=l?JTt)m028&32-zl) zJ(L&vdrV&93#jq}90x1UarxI}`zicO@Q1=5i&r|&jxS}kv%1mdvo{giw2KRuibXBA zEL4cd;2bd-0Az}Pw0FTx58#)NBJkIT^*OHWr_*D+g|#iEqszJ&*r)-)F_Xft3q@#qF7mY2ji}TZF{OhEp2>7VJw$ZI;x~XWh~BfgbD#CJqI;=#Qy*rC)6~H zUk`XYP8O}?C!X;rQ6J0xtpFXmZu(a!-U;z;rX{+MNxPL;G>paLBkmrFE)NHS z5aa!ydi@#wn0_v48vL3U?A7BxhI(G7W^S(Kv%9+@K^78QjI=UF!JBcFcLrQ54i6;w z{{Y7ihFVUS;~x-wKG(b^_IfqV{meFYH?YW}YdoH_Y8S%|>-2%;Oa(-dZ)MZp-mc3A3E@hw)E#2ukdT)Rri8#rj_ABWU{@LFekNQLOAQs7$ejg<|AkfU@QF0jQwi+tNGy>nm%|!K_R&q=lLK0y;L!zSI8b`DuWm!JRS!f&+_S7_0+QLY|pqQ z%Bvi2C_YlW`3Wb(%O_Ms|X+at9a(j;9~tL)b}4lSdH>W`cE93H!a=wsY4#y46{bwajts ziOV?h+bx*!)Pd`sl-OCIn9IBXpoLJoR|E0qj{clfPb`tlBy=qvLqy6jLms0!_Ul;1 zS-XJim@UFe%?v?CHjr`w`ulPJ0M@FoMQ|g*gziVdkPsPv9^YSjw-v?B+#(2pbzWpv zAdKUuJ$rHJPrZ8Y>^1u}>L0YP$D2O|Uf3e(mJTmGMLCH4wy}`g3C0O_#A80B;Dt0w zaf}<0q!MiXGyRBsA**XX>YG-W5!x0s}%8D zq9l>VGI?zE>_+jBK)~m=D}qJ)Jd#@^Z~NHtj1(SlGm+n>E2Z#9iF_jmh3|YX;(b;~ zgx5k%=@B893C2R49HuZwIpf>wct;n8uZG3t7`GV0v>c-6K1CX^jGQ&PicPya?_`g| zJYLFkY`-Bs16=qM zSeM5h8Pk~PI)tI55k^ngren(S{0wqI^*evNpOb%T-`O8Y`1_`h5_n=udp{Q1#pS^> zkg?yY@*FpoX1(vop% zP?Th@oLbSR{d7Ne@c#hBju-JeA)2P-y`1MtQK<(`tvZmCjGGw(XXy1!FU1cS zcss=&0oU$)E8`sk=S|ctTWr?1F_mC_dFl^PM+9Ju)@yAf;9wtm{Xh7h@RQ-k#$72i zKZza|w$wb8BGT<#r10dZ%1XR}(Cz8$dRKyYqS&`3~gv0Gj$gO8)?YVffB#kNBPZR?%&tTn09lv&!w)pkRyFjP9?X{sex? zzYKq8y)pF<4@V{Tw>_~-jc)V(X3{=niATzd>IPWjmhD<%v#cEoO~PEyrHrr~HJD;6 zC00DK>e;#APXk(fKk%lL;OiT~<>^{DHOdKr69Q20`QuvCn@>`E0WTi_UXMVX*4;DmyE?E59~%c$dXFHwAEw zY{M{{jT#Yios*MMk7TuNyCwM9YIXiSxxPLS(QWM|kyd*Q$r`x<@Uh?#-xO`Te2}+C;$ck=j7z$9-Z;Wi)3Q1(5Wwi zk~0(BaC7ZnQ^!?fQ;8}iB_$ZsjnlgKX~8s})>6LRZht?0AMn*Mz8B%AIxdu|)Wub5 zQRSOeF;%4pB=x(d;l8%A=x1sZMm1Qi+AG#Lq8o$F|hMpn#fF#&-og(Jq*aau|+Zho0 za_l)DfUm}@O~z?%Y&Vhm<$0!(HtlY%7E-K~Me<7O4w8gsdGb*vk+}Od%BL^h&kJIU!mJ=dJ zp! zNjsr0Gw9FQ!~O~#;~&{$!`>kHZR0Nj#pCY<>l$C#SHfDxxh!_}V|>17i6lv)NTm}? z=EdcH^p#@6YX|B-;IQ8az8HKO{i}4_4~3ov(>y`tHL@w-DR4&)P2FERilF zi*#_r2tcaL3lLeEkH|CJ$rOjlStG*|vc%Zw+!As}zos~^(@)ux_ObY3{{RI%_(dn| zd;4E&Lf^)>nuIcF8i$AUQ4F`S+gpj+>Ny=Iir(uG-3d;}kjHVdA(mxdxyh@`I?HF2 zVk)ISbw)i>YOZ^AJ`kmxs#N)wtaM|42Tgf2%Xnl-{3r2!W?@vjV)#FwTCwR2} zC3SVWcIfQ;qF`Wqg$in z3e(CLr=bWK`vNo8yi?%ch5QfV?}$1_g*<)Y$+UeY`tf49)FFvoqs{=`zo0Ba@BSQv z&r1CcO*`P%?7Q&Q1>U7=pm=*ujeFw+aANm(@)@k2x{d#C8xM+ ztN6nMs!k8fw~k^20)zOE&HdbG{@1^^mZkABE1wQ&!cPzQQVCT>gAi_YQoMsE4-1e8 zjBq-!J-##Xw~w_y8R{bEV_Vw9zqz+9xGCo>!=Yo-J%Hl7vaGI7N~Kg@GZ@=&hs={- z@h+$0y;5ttn2AX8BaJ#lxg>G@J?@q{{RS=#d`AW)%;dcawb^6Teip3_G#=!>-VMMShNwzHY$^ z0KaiSA9uC}I3C`Dxp=EHj9{>IzJH=TY0H_VbF26(I=97t4Mgey9wM^*Q0hS(V1^ZZ-71g-fk!}pO8O2~(y{w&WbE!;jX_XT<-p{QfcyTnD`f~Vj=gYr z>x!ct&Iv}(T%6|_!Ov>79K4U-0;CLvV~{`2YbiyhX&2EKA(PEuH(;(X3FOq-IKVqV z_3A}cj&u%60CSA<`qcaU*w~UACmlQ15>Df}EMc~11wlXF>qWFoa6!T5vuTFIfg^5DC-k5=Qb13X^(5n;r2}K_W01gclSmAKcSv*3rB2`~ zxW-fyllfD)1hL>}nopgM5N!bSk=B)P4#JuDG=Le0QQwLNLEDqS?L&gY0PP&-){ru~ zZdT_!5-0#-N!$s>ae+Z@y(YOcjrD@!$X$HUu&P_ox1^~GL zcT?+*D_F&j=?-@?N>e;zJxIkj&5VLU0~igTTz9BtmO$AL!tsvu)tES023tLFbH_^6 zLqc1*C!7Zv8O|^%ZO|}a@)!9C!*h}7Y7pQz&cNlcPjkgO)Pl~UP)Lk`0O&DItpQ0~ z{%<}h9Q;l38THq%KlF3aN%2SiN@s|_>^pb=0HdCY<{A3&`vlq#_+hUMI+6+2wKJaC z`+e1W29s*SH3yu5$Tw=OhP>zg3F)Jo-`HouTmzj`RBoNWv@Yw}Ec8FJpklo~ z=_B7Y_EIO&Iepq1gCRt)r#^!LZ6fjXE_}OR=HFv5DYS&02%sLgt5wFB}U`? zAR3SC?!>d>BcKM9XKl%Gy0MRQ5W#jw$3Q(ze>s2fcpvN%{{RktK6qdD^YB)S<>@{< zwp|m&j}&-V^?O+V0COvG)@xzsy3pxOHE^PeDh5rb>KD3I{5qbScO(}U(yXvg6BTIPPE~;&2p}5$Vf~jsXe~ST znD{NG_|Qs;tHHL{#8#0VqASDzvjqo)kwB1k8;bt`39reF6WgtOFYHJANgvtU;^cQa z?aYJ2de~b{eNjsomp{9?P&%|scM;W)y8i$cM}fpySN5I-@-Kcmuho4205U2^o;E*G zFJKa?I`hX6-E_s)!buNGfSs3jFg&ujvVeu zulOtR-xIhZw-jJunP4#$aQTJ_PI!8hoUeDuDbtlW$||I4s!@ZB@7GlI?RDQ5e$Bo+ z)_fo04Gz-oXju|SqML$thDlL^leDWVZNNN$LXnE~JDZ<~I*)}dFD&hBt@SA|^4ZHN z+7O?WcOH#^I6k=TSiiP!!Fomhx0x%HV*Ql-st_jsgb-hZ}3lA^o3xU#c@|J~iyPLolG~U~wwdDRVUMqm1LNqctTLcAe6seP)&ayLQLB~hr5zqVt9)CZ=y;!v3eu$-g%XShJx<+Y94;ckxJgDozImsva zRqwQ{cHv~e8#gCFLB?<~{eQdfSvJzecZawQ{<))P2P`{-)8E+jtI^!WGk*Ru>Ooa9 zaD$ADk<*?z^`cv98X+Qa4D4epPyvVAf}@ea&rH%>EN=?23}j)JRpj-@%0@*Z%PjW= z9_5O#1RxzcoF0Q1C;aqARWdM-WTi2)aoPY!LAbsM&r&<}#b+d}C9z{b(e+(lS%XQ_ zwE1psE+kl>h7ehqnS!WQ13CGJ9`*YP`wafdR$sDr!N@eNM#?F?ORS>a_^S075Rw$i z+pY#SLRC>-c7|Omavs?9h?+aIBri&C8tj$l5XA^Q~x} zDJ{p@-vCTk4A~@P5;}3yK9%HEi%&zkDLst|?iH+*NKxGmRzP{d9S7hIRJVdwjJ%9D zg$xS}Wl8DpkWE&;xVOEZ?Qt+FKv=;!$j2EJwv${f`fbKP&lSmdxfGj< zvE29@$8V!%-tp#=%x?>BiVnln70wA59WnQdX>#Uyblk{LtBUpq-t$4;Y< zQD0}2aeYiZDp)MalyMknT_rfDI7Ue)$u^sXteUrEyIb(yGVp5)kH+Hj&KSj2%d;xV z5^|$bP^DI-IXT7GqZvXKd9Nr*&Fv=IZuj%QjDND{!`~I^7ZK|AM@-bz9oL#x)e;lV z=Ui=kzH7#{{{Yzo;vJM~JTqT=G_wl{v*|<9}<(&7-x}m%X~%ULN{F=e#W%IcFZ_)$lllMx6PhPFSinWvps) zJSwm3?I^aPH_W2d?RdF-BHy!T##p9gZvtF~#y2F;pL}P22qL;~3jWBy5w65zUY6HQ zh1cduB@;e4$=bXQ{ZHl6zQu0<_;1W4`2NH>&&onzejaDNM1CIlWt#`$HHa!ceW3wA zU*2@CFZC0MlDD#l{tl!60Lu@itMsDa8nLu-6<4MQ9sdB#=aBq3_&?yEgmz7DqT6|s z&BflMDi1bA=jF&8FRpn(v>(z>UhNDp8@aih+ z>0PfaOivtm?#spUK^~W9aS*j^Kc0$97a)DmNXN_9J^FOdIVf1vqp@9$kcD|GxQ;pJ z+wrLEZsAMidsrMtHdWXAr<@%2ALq4CsY<>~+apj=06ED3j=A*z01Eh=+c(B%xrFh! zww+1Z_EuUXmW#4T{p*8xKP2ED8^h#TL>w^mRJl~7Ry_J%&;orB; zmv^gv%-%NB>{KhRkY-i`hD0drq~p*No_qJN$wg!;L2NgiQwI#W=Nn1RI=4JvSM8s~ zI}1+{czWwg)JnCr^@G7WlnlisQj9QoKPfy{S*pEu@%ci*+J-kKKj2lCf3Nu z%aF^*C)5w6XIRf~1cj0){MK!zNf#$PdUyV{a^0hfGjBA2GbVBYz#R!0?0WwI`s>kc z-p4a~maZOH_jxFw3}iSXjyXKz_`jV-7;bkCGBCuV91h@eI`RCyc|usgAydvXEO|XL zI{FU3iKZ!yEm&_PnE?$U3`heYpXJX^m2hn#HDaZ~h~7N12_#^|C~k^-ecW@`pVU`H z@W0|Wjy@9j-%I$P@V~@YT2{5BU8J^mwzC51CI}#c2xVYOv}_p}SgWfkAoBX0vMjO) zZNmK9hT((j$sPXytxsy@)V;)NuFj#pQZP<(ex3gS3hJj)r8cJrX)RM-AAeJy6NjOT zqdFLfDZ**SQIbwFi&kl;bkkZi^lNkU$NmXX`+fe=AGAO1A${=gC@>}7H>85bK9M$VTL!6#2;vFfS-|Bi1~6pBK@cT0A;_4zp@v_5%8blg6Xo}KA#2s zjpmzn)7#!dA!Jz*nN`zw0FZ#1$Zyw4abKqbk7X=m;_mU?z&1}JMI?S*)5hs#_J?VN z?w>k|N&8G5B0No(#{-M3{;>w#IZjQf^DDSE)3?I8TYBEh&-uGN%B%P%#kF{&BFp0} zqZ0~n#!6+*MCjQetA@LTZBiB4rsCa@+Iw9ozCUBnil4Iw zi*;Xxdi+;Y+QD+|plk7Iv%zhxz^NRip=k!zby<;RMQ0Jnb{2dFKR0|d_(Ab!<1V)+ zh5iuyG}ruDsKaJ%t~B2cXtUo>3@FHkR*n>A3K`K!BLKRLJ| zrGNo*pOktJ!n+UI{{Y}Wi~c_N)-MY9M$P<~h?7;+WlTk74spHz0KCVLtDLbo9A$Q-VI>@i6gGQS?mJD$h&m-cG?lD}ua z*i+y;YaTD~_K9(MeWR_-o`kB%yrn6% z6qD7-r2IejSNIR7cy`-I_zCdB(G^*oBizcM$mL7RJg25`ti%kQC<+b$KKpU|2mD)- zLh$&H!)i01`1-i#uOI4umE1+~gZ5_dPl0W|HFz_@z8OyuT-v~iqF6L>TtOyXWlxqo z#k4%Y%dtd(P-KH0eAD7xa__~yC)7MgdkRHysNYL>BEqL&c^#d$f=L4{j!8JjTKR0t z0m8gJN)H<^uGgyh{{VMU!ZD`O?xk zvqCkhRK!M4+9@E^mSBKQ6hE8QX=GXDVW zD=cX!-m$#f*o2WonRYB;TY3%^fgTAysjGZJ@ioooiZw&1U24Ua@>$y(!;EqU08lUj zi~t4z00CGp83>WG942=wHb!{u+OB9fR+e`$O2`bNKzAOE@A~nNTJdw9Da7ZM>C%i9 z96TSnN)}L!+jr3<<*U4`o7Zcz@4P$k+b!To)Wy`REM8$>m)l~h-PNgO6=dY(;Z1U` zT53%zC3j_S+dtTc#W#K`{g3_^_=@h}$vxMAbdNStxZcZY5;J}!O?r$|F>nAo19O(^ zUIqUE1k=>pQT>K~8pzDbQ^USl5-=q;Pb_|RbCOT3E4DF*11G&QIRv4I z2LR)djM5f6!tLPkocdA$<9|lS9qAtg8$; zu{`7HNQ5C9@srpdzO;^1;{=|0^rV%-$+$S=V1ZVZ$6>7^hb3I8-O!)u>CHm!!zxY* zB<>$db25*Z0#AILW}>*5hT1w~o@-dd8%uGrONjO~i@1#7o~IpZ6EPVZ@;a+z=A)kF zS#yv^cqATw&ox@&Sr~kqfCDFSt>I6&q?N8&d0oJe&5Zrr3W28IDF#BI@CZB#jjjx5 z4&)vQU^oPNR8n04cKMn9=?6HiqXJvj<+DCU#yL2^=Rb$FNnli#0B|E05x z>#kk@0O;qUxPzGevi|^rY*?%Q!yXM{oVDhtN5A|`uEl$zX)7XtQ~l%8ypR3~jcn5X z!oCU$K+M{qtdVCOvZ+KQNoIM|rN&f{%hEg4XtC?sLKPWh=7*@;j&&tqH@n`c6ekmXhhf=5odrHvIv(sD=f8jO}_ zAoU!8PZ_7WNW<@aaB*4cJp_&j(d2bZ7i_W5Q-Cx13STS{h8Y8%O*siaDIAWpV=bJX zJ?cstP-k{J9=NEbw%)8Ck0x;O5OU;ZKUc zVUG^{J^0(>yNwgV+TX-WX!Y5&l}U908rWYLWr7kxoJVwRT0|Jkdx0yQ*YgwM&w<_v z(R@W7ovr!0?D0DzK@8?NZX7Vjm$YtSx74l=UYI?2iNdraJ)O1uulEHibF)ckc*O%5 z7wm}x_+)sJFWK8gwzZMNMQ^ABB!SncVo53vM;YXDF~K~4#Glz4!8#{_^+~la zg{`M+5ZgJmp2|tgTmJxsvSlfbc?EHSpK8b0rnF@{qvO&DI5emMUB)(G@cW=nUiMs9+asEh;;SgS(rB4lpUjxjzIgY z2clp(U4c=cvyUXLWmFa}D*q{MPpt^Q@M$+saYnk#JR%j1a(%53hRqGQ-5)2K}Z# zXRS&6L#Aqf@Q`>1!zJ5Nn(3HHH2PU-r%#rAqlWI=gt9ZAMbVSHx78 zEm}_SwJFu92u>|crFNX3Bv%3AZx%o#eNO)NZAZf%Ahj(Z-dbHc z(De->(!%Maj@>m2*;V8Uc9#nrZWM0}xHvgNaf}o6cfs$3o(uRnrQ3Wug3{n?8t#WA z+RmYEeIQv{0~$WZ42>Qd?+8&63Np7zzXth5o{{U-`*}=XOc!c=B!kT;k0Kz+eDG-$T&7_QZ ziL?AUIYan$>+}85lgwR;2trD$;IQQTdw!#jPt-pX=-wXjW|MW{PY`LzcV%@6itgTI z$b6ojs@(`9p&9L81b)+hv%a_RE5)T_Lv^BR8=p$QWxyjG#N>35K^=e~WO`S|aQ<@~ zT|Z-0`|I)l0D^yB(WvcdGtObU+Xm#Bg0qJ#(sKU*UWb5xoleYK_mZJykx6F91xG{N zdsQhk_VSgth2$NIGB9}L1JkWt5@O%?5aM8&M=!_VqwpM>`djIvI#Y=DJ0py);vBEa z#PhfUdF$Wl#(gU9hxA+Be_w;cx(<~zw;GMo+uK>H7T+X{6;uwqmIIGkqa5TYKuUby zD3E|h11-lj^^g1#FW{VC5&r;aQRD4m6@5R!)-8MzM=hJl4#+-UoP6)kU4sI-YgTk= zw#OLrJ0GHdviHEv6ZUfWBcc2y(pn<=b+VmC?Gyubc)%H@@rtGvIrj(c;t zw(W3Ak)5Zp1NoW+B{qbW0U0ihoUrxhxa&!FCRvrY5Oyjisx}|I*@VXXZa5qbd=$P z^dJ&=>&N5ut~o-|vMEWmsdr3iLt-^$Vhe0NNj&lSbpBPTY+hsxt_fU*1E}NtkHWKb z=88MtB-~`j56Z)aZZnT=nEY#6)?{c}NaH7RWrtiJPk!{&Z7oS}nO$@k>vFI>ag`fM z<%e&ns@C$vFC^KOEIO6#{d@kkc1a?b6pkkuY_>;CkN&-9N46`>yKtxo%KDss1Y%`ONpgd4-}ibDLuM!X_jW|&1P3!z!K$KlB3l6@k@IYt{4_mwBdrBAJdvhfe6g^ zAy+tDWP_h?Ok%BBY=OJjR%qr|wYi8G*f#JObRhBmeX6XJ+N7%1t&j)>NWk0B{*>vY zk#mVu4abFEc^|K>MAE@;6a|zp^ER`PI30#Z;ns`vV^>=BB84t(<`0qOugFw40AzlB zYTebiQ8l=79F-Bt+=0Nyul1_;w=l_SM3HcJGPXD;kI%ULYUPr%Xq%?DE~X$_M#k&Uu|eICI-GRQ2eo|^TxT9-vF6Pv5?bw#O1 zLSBzT=am$aXJCx6T&_uD#sR)}PjTc|lJUOdcYu*~vE^IX&KILw#uk9{cDWH`Y zGCWL#G;F5~8Au?2axqonhy5Aha-T8fl~`>jpQlc7{HdR2k_K0Rx{;R+oSY1fpGF%D*c0WWk&)YBdKK-=6WA6|ADgD1cXm1f{-Ye2PM>YPZcckb( zT$iqf^tFUR_kB!WNEopIIZp6<#I{f{SR&reB!=`8OHgJ6AJ69aCM^&MNz3a?aO=f>nlC& zzNg@|zK{GZQx97O40bA>0y>hU5I!sT=i=4O7v4Yd6u0kb8H(7nNRh+|%Mg*Os3(lJGH{@i zv||oI7$CHmkc<#IkN*H&eFN~5;XlA_GvI%RynFE@Z-YIs^(Zc zikrk%s;E*}Of>~r&TyQq7yGLx2ul9?zSeg>c(xAG`RM8ZZLX)bIO9J~D*GeFC@5sg zjE&ykg?obPAF)fxcGjN({7*94NLJ!&d8Kt2&p8*JD$m4EfZhi9jqr0y_@nzfd_UBp znroR_Ye{a)$kyxTtGP)fea;>w+Oe@?zo9tBuYqy3u=OW{#Zj*+aprKQ>Pl`Yq}ooJ zy_;9m#(pzpnQk);M>@*!msjqllANQFeT*FVdRJgnX)-K zZinbTl>Hn{EZ$Q`u{(i+a0Ywzs9=^9w{=EbsVA;8&JX$e)cOPxT-?a8<~EKpaK9bF00hg3f5Qv>IxY%_(=J?Rpxu9R-`2hQ8Bjj(MmCM0j@a})J6G*EXfLY6 z(H&_>(8{pNDd3-*JwIBU%&G^N$`p)bjQ+Jk*5QF&q+zj*pXU`hn9cH`=R5`-O;n?) zrQ6)RlC08UWy1H!?dknLD!mNqi@o2D3iH>n{c99Ww2iVUBZ2GgD$SklSz1B2+jkiC ztS3(BtGgS8!mtY5V}K1y3t%o6_;|?sqtMj_d3T9ig$gh`k^cbJr?vX32J<*7xgg|` zp1-YfI5ujf(7PqbP^%_R`QvIlIJ6iw_LX3m> zR&s|y+WQxbj6n(5deSp1H=2O)!RM`3ig#0zMhD&_oc{oy=~6Aiwo0DFjPX&)VI)Y+ zmIurO&}Y7CDP{qeau#B7z~iMddyg(L&pE-Si=F}Bl0h7Is*GKOri70p6$8yY40Wc? z;HU`M1f1onv)z#?`GNdDT7mCLmuPUx4=TeR`TQ$b(`30>=ub7rmRSOEliT@Kc_&iX zd<=4Odizuo-0o7~lg@heH48`;OYF{Z!*h@Ry>A+1N>NWzYl*&LA;B0Ulg2)^U8j{* zhnRD_uI}F7l{Q(_#pSphw?9woRgC5`e8LV#0C8K!1i8M1?<}dw186zm^**Mc+=`6H zARJ(0kxi0DbAulr-W+10c{3(d_Vxb&Ju7Iyh`BAtJ;+h!F#a6%#%fJI^DQwB+*`f@ z=NwfVf*4^2bJrvDJda8=_F(rpJT;975@O?g?tP^&Sli_dk^@T-Sw}25Rlz+ ze}Lnl&3IS*66HtyFrR_AJ#{~){Dki7-q>d<0V6)0d9HOW1l7nyD#z}QpLkO3WFf{* z4gd#=nd6C6XXW-My&|K7v^Zip#~H;wVWnkZRz0HtV?9P`_+$VXPvR*&&Bt!S2pQ&` zkXQmbj}htc-KUT}X>-6MnvvK{V*vjEI#2-NMn4KTIPc9F zGz=KvV>FDcM+6?7X;^2EdNDvBi9hgjCx^UGHx#iPg_;bG2*z+7RPk%B> zK=pA9K#|Yx*U(qTe+$20p9txia(IVchT~GWwz%`HA}UfNv@v19!sm9};PIcOev1D9 z!NT?ynt$yZ@Uz5>@y9Iw7`%=*XBj?C)NP%os2+At)O0w{%^FXRA5-xpJ_^#Uo*3>d zW_ydN6|kzTigs<`7>D4K#yTF9t2pzk_BtTyawQW(;orioaQGI|+s1EYKB+a1b(hm) zF~a#7+PN7hyeK5(fs#gXUQ_XZ!B#U{h}TXRDRm-8H`%093Mp1QzHm2mKu=$oU=fl& zhncRljWS!PEQ&|9Vvlo-n0E4TPD%Fzt}DyFCF&5{_+liE!Vt-L?RGrHm*t4CLbz~>J&nwI zoG31BFW_fD$j!0xtgIMran~4QkVss9QvSo=wbr5gH~7)KKjK@NrSNvQaTwMu;~SpQ z!N9tRAL0XsU*l1Zs<{0owA6I%R?#(0GgXe_+FNy+*6IfvCyAGGI+4KzNF&r5{Hq;? z%=l9`jxPgI!bPju32 zRDC7zr}mHdhw-1|w}bRA6Gf%k*xOy++$1`6yQiHXvA9(%UEH0j&ww_Mn3J3gY51S^ z;I;6ljP;!hSNLZQ)%J_1-)Z)Gm95a+$|owSQI{$RjxzXsF^@PL0pgzk{CT1LJowF^ zc(=s25~Y=-(81Q{++u*P4vHL;z*I91^d{x({*Kdv1 zsU$|&h$3f{65C|AO~|rHgJ5vPah|5XRpD+X<1FikT|5S153f%VE5@Yh^Ufa1l8+>6 zCx3b_&M7q0*H(Wb@#g|~S;Aa+tLK>|SCmr0(@~uYRoyCeez9HDbtgrplZ8dhrrMHe zS=l94&gaD*Fwu066aN5e#icB#&AHRa%1;vmL)c z>DM|;v%{tRxA!LOyd>>R;iM6x9tPl`_6EFP$Cp18^aw0|4E#vYu5GmoS(00PznIq% zvf?FHOstqYmvXTHji;40(s)1MZM}_)>i#~uTT65sn$@NGlyxRYKX;t`qn!7zkIHlG zepkd4Yb-nGp(N&}+#Z@rF6-U;(vs0{4wI|54Boc#ue0^jE~3;n9jyrj1lkDX1>JJ{{U(qhkp_D{{RsDIQYe= z#+o(Op0L?!cKTbh$k-v>fD!>9ZI2rW!iGY4Awl!siXXDigJJk_r2I|&o%KHw_?p*H zy?c2rO^KYkVny?Iq)6zh@vT%y?fCt5S+{blXb3G^fjMMB9>A zO-Utm*6)AMICsTd-zLRU$oTUGEG{X|G^2@xXDl6QI(+I;tm7C{ljV|h+jeQYrF9?4 zpN-$Mzl^>G_?GN&BZ92?qeF^vL44i%h&r5v%Df@cZKTgzqfjyH$~F^&5A$m@5`hG-5>zq%yYF1E?dhuZKQA{>=U- z_zkH6BTapyT}ii1)e+;o>!a7yY?8R%T;L8R&V6`pS)plxO`O@ z)TK2z%V{Q>Yssx2$n#5SF3;}JlmLZ4CRhQEGyFL9>&`3nAN~m8;F~=M_HxzyS@wKO zu6TvNe8-aTxrtzYa=5w z{{Rm50AshHZuK0u(Ek8v2#YaS`OC=Lh6>~_>T5YgBy5wEv?PN1+FN9b9wmvCxL3vx zV}smakEo|ja}>7^cLMHTEQt;|;P>?EIjVNjO9b)>m~x@laqXTCe+scQk+VBmh6qmZ zg!Jd7amrk}Sf^rWr3O}+Py%B4qg-H|gZcjejce(bg`6qdwNnj_Gn0(t*06)e<`Zy* zPTX*F&QD)z*wT&d+aZqKq#`^Lv66a&)2&LAa#u6AE7M}#?1U2DX&z_aet`Au-==X{ z)5jQNGnOn#*cH2Aa%;){F8pfv%kYQBTBf6`czkN!6GLtG2y_nuYFAo}A%clZm2nl; z3vuKybWq!(Y-X$9{@7j<(>!5)Gm~jmKzOj-gidHBo#*|Ac8T}VzK0@7pk$V(6fCatClgbY;p)7 zbL~*Y43kDwO8)>M!Ei9#b^d0wwL44ZIk|#J$Q+qr#(R#q{{ZXOZKK5=^k8{la!xXT z4xRr1N}E2VMf;?WwnPz30Q3ited@Kf_3S3-G@E|W}UpU!2F}ACm&zdvvwG4r_)Sj z@o+-2sR4RqXR_ls{{RZ>rz=NNB-QjeuZ~{?ye0cZ{3g}D8F~Z30>&$+8LL|6^r`di|nJv693yuICPI;(rE_N)iZbrhN zIx+J0?0X#4!gh(t*jN-E-ALemzT%^qB9x)G4BNp{83)i4)RH|1Z%lLSMYQ%fmF_m+ zy}O8}iB?H40|ww{p1k9(e!V-?8hol2GaMCk#>BfCX^sl`%y*WP96cO93 zQ7QyQ7{CfR>DXevXNYl;ij4&I-JJ8fb~}%QpR-rQUyd3j*N8kTbEVr^THS~|?Mmg; zO8Wycs6=+)G7?Bp^WbL(9?$VV!&=wDZ`rrPz8%(@%Ifb%*WPQm?B_ck7yDNzTq_)T z@q>}UYCl(A#o?2cNlwZtbsn6Y zzx3@{^DX-I{&~spd_&@Y8P>;QGYCfml;$l{qY5dx%9b8BldtVGzqgk(c^<1quDN!! z?}~aX7n)y#{ut?!MAsKVuG;NE&Q3@qlar3DdSD7)0epMaJ_>wC)bDgS9{&JFxVX4i z(VqF^g_!OA@?aHL$|&4W{JCB-2#3WV548I|X7j`cNJzAMJs*5iTF5hTDxd)5bYf0@ z4lB)pOQ_{s>?%mZK6BI4@aDc(`%^Q@adhz%C&^MS@oL&z$)=ipozv=m-Nf8Il<`(e zS1QeDsTRadB;B07#66 zNn_lqK2#(|g8%?XC#fd9-@;chSv|$g%!C6A0zmt@K7**o^RHR`yOg)aJu_>6%()Im zai8>yj+NuzXK1v0`^#gO`#hrFA;D(BDh7M!1M6PL7h$PW<6KobNkuA~C1um1KRoa= z7m3FBQC(NSbq_RcE;H&6>HO=2_($QWd;##I!#@w9hiK6>3mrG6G2~mv7i{-W*1D@L+%w^e z4uIp@zeAx*l7v(2kJnvB^hU%t9#}Fnav1FZcJKV^)K|gaDo)lTgZ_Wd6_%6vVTgWD zP&)n}=e;9c%@gj8Pkqcrd-3>FbneEa<v(v2w$jTnvDsFjtbNr`D;g+fH1N zz4CoO&-19HxhTdsn4Sq7ay_a0K4uPD8V`E=04xCPq<=cAE~65bL-K&YWE^Agsz&Z< z03psjjz3C`qyZ^T+Fm@S&1RL;9!n8{#6v08}h8|Nj+5e$N4o_N&=Ai z1RipEG{?Bxx-eVr{Kd~v$E9qk*qF)Kvn=lG@`6rsNN#F5CkpJNAmu_0fyeJ3e&w4|01Z7DKGs7<* zPyYa2bLrv@@|g3`e8U(Aj>eq2)J_5i7##p^LF>=yTYX+#h;=3_v3p$6Qx{{{X=*G5-JzCE#rT0Mt+a08W>6@A>XUa2k!lKX?p^QaQ*S zy3&#Ed88vN^K|#1RsaWnezagH&V9~JFF}AP<3EQ=0>T1wka#qwfyEdHsGt%D&H7L< z>T!?8mXHuebNLz$de8yWy%}uhf;&?F6o=CsPzUAb{2XoI`)kct;Qs)O+7)@AU2e)I zW&lR7_BU?b_7BKc%>Muad^@GXpus)7UTw^r+%`XlgSW3>dUfaPUub{9!n!Kmd>r_7 zu53bWJXLiL_e_NnS+P9!&z3%?9r0fg{?D3y{jQH_ZC7*MT0E%Vzz*O^B!3a(pZ00a z{Lu9MP1o>#XSIJ|={oh@%IX#nzm;yWnGwM|4?B4o#^KYZMtfI+H-j{brq!-4M4nsT z+*>>{%aDqb6a`08P6@|3!0F9;J-5Vb&kB5IheLSMIW;X5w&_^3Ei7_pV8(DrlahM& z&MS)Xca3gz{{V?vouqm!Z56bz%2lNO)JAVN0RbHG8)wjRdeZw7rSKf}qSUn+MHE(sAPzaN zt-s)&{{XgSw}Yj&_!04ISY*;=pB@_3rUbL#a22LaoB1tb(vm6 z=3+?LAeI=Tn_(&$!1K18l$MEQX80EduHLV&~OMO`( zb%#>7Ssg=V)E$gfu;5{VKS77_hvKF0i?yTh{{X^L*<5LoPYGG$Of$-G3acJAhalv2 z828W4)W+1xX~|(HJUwV$KI3qxo|jFgovyW8JrC$CZdt;-X_C|W)e2a=H0?O4!WCfj z(n(24Ud`Fs>g@FC{A=JJ+Q;Ayv#$7~L%&}VcnialT{ZZWA~nH@+`0rISV{I#y;%m( z#3}wN`I`4l(R^Rw%fAkIhfPb{JH#S;dzkX2PfYg&5HLqPSGfEym&M<+HIBM2{vB!l zCen4wNaED2W4bC3FPH#r$ibr`Srvg&4;+99J&%h&XYYXD8T@Assqm}f=ZAD-s8~i0 zfu`F=?<9()yw4;|jS=$p<}iaB3m!{X(q@@XVaJ(v7a50~9%W4^RfFZHI-c~EIjKp; z6(pddotu`)+ucw3_bbito&)h7KP=4f^{Hf)r0U^nr8@MeTAU?XlZ7=plw#`2a*g7n zyI1$*za{?wXbo zkH2MK4){01z9h7pSn$n>w$}7%}hZ$aQPDNpOk3^Tlz9ZCpDXEwx zlTW^$&hf!58?G8YvMc3x=?~1Kt^H)x^pKNl2 z)I*cB1JJ_xM%~+`d=?)(l}e8bjrVFr+V5+v8hxLc{X>^ys8^v1u$MF@l$4iL)zkE7 z{U*7Uj-io-RTx&!BaWZ`x^>&N<*(VxaMAGFz`)P*@9AG7e#>9*Q@;d!Q`PN$7k<#N z#o(_R!*aIF_h83ms7=C*7ykfJAtYrREbg*H}0m|N9LFPdU>x#Ke0o(?ZNS

c&<5ZbW~%?_1L=28%CH+0>G+~z&^R_jMtBT*&nlhlksQaZ;3p0 zqxj#)x?hTP-w9t!qT2Xl#+R11w--@D(nKE730YXWg~9|ck)6Z}^?^E|0I>|p2IUzl zbIyJ0dkcWi8ABo3a8 zo)^~MFAhiJ>$!@my|hmpGc<%aOr-lm5H~Q%3z8cz?YsL}d^G)?{7Lag_P6+Z;b=Ae zBI&I4FNZqMfHf^W^<6SniZ*+FMGUsHqxp_XGPHO`$;&Q&;nXE-<(ezG8h3b#v76@e z5R4UK*~q}b2O|feuP*V2>`(hJd_vGJJSDDtD2m@kx^pG&kD_Zgx_zC)a}NzN8cr>Ek-;f$?OzU<)qEAh|vCHVLIJNRez+woqLWufTa1iOv2 z?L$(wD|zATn_{T8L7^F$Gb*aW10$=9BAntM*_-0W?75)$3eQ&kp1fD8>0b>k^?B~? z2Z%l!U;hAQ*~ui#(~E6Ip$=~kh5D{$8Ke|xm8jWkBdt1yCah9?EypLl2 zoxf*Ki2ncx^=qb^#`1W}M$|k-b*gA{L$2IF_EvW+9ykmyk#LcMNW+z2+#;S&xat}# z)}MX@5htXGK|h$Um49mQfDujt$TH77 zT8muR-bC0aN=Fbm`3j6O=J1uVG3QM--K@O2^e#1H7iQV|DIVx#lFS9c8E_ZV1drCU zuGUDe9n1uIpDxuOcaTTvpU$cHOU70nKJbQ@{Da>*d`aLT0#Am)$S&h7f&pNN<%Vk84NRm zKM`FplI5_kd1{F?o0Bv`NKhy)%Od`x`X6f0xq{zNhDh%1BAy$Ic1a?em6b>#fzM*O z2C}>@YZO-ULi4Zg$oWa-Pb2w{(zYX#KP6G(a>>vU$G!*WTEb2(G9}9I6Y^vB-2Iy; z{hfX@-e?+SoQvV@QBCiM?$;*}7Qu@?vA4=crqkZMXxcQfGsbojPUYAKnd&j?n*C?} zx&HuXh<-o(3%s$?Eca`AF>OObWx-~3AH9P;K$1N1{otOJ`8{^1iUEOKGLVW+2V4)s zoPBaD=&;pYYDx7x%F%S?WKek8X_DUEBQ3WkC5&MEoL~{gG1oOx2?TDD%&f|)NMe}B zUPt6TYR%YHB*Sh=87-cGo~NIF)mnFPASoM#Vdg{#9AIFabm@*b>(aYp?;;azT)iZT z<}tXGg(PK)hRFki`s4hXktBPF-Z*T=8~2iP@`Kay#(szJs{V6`Nof^#j1td=CnWRh z&maD(?|uRNAk;oB_|6SBM45E^R#~-Mj-nWI_cO@c#{dth>}gl4PK6gay%93CuFrM- ziM|!Thdd{7@fT9I@~(A%mc?0dw%Kv>jP`a1{`eT{UrGEu_*wB2;)bU#pWr_XOQ^k? ziEgD6BuNsSDBSMKyN);+$pMHTDaR{eX=SC`!=~wWq8nJFSnbgO3KU=uzyZ>}pZIa% z-81%h{hqbo8F+)@HLa$j<1Y-X+AX9E(a#*7RGv_YgKv@ZaS>vwJhRS78>{Cyb1SFC z8Qn^_rsIdF8BU~>qZb&czEqsFyrs(>p7*t`&&qsLwp$SMB2V41;cDk!6igqQ)fUvH>9D zAQ8;|%|8zGdAv_I?GbI_tvACu9)YX1s$1&O?pF^GSwl$DMzLg0s)sN09#n+^v_EW( zN8mO80ETRJ4*=fi6YIAcyUTMngDh&)V4UBfy&_^khq-t4-XTZxB zcH}6)z`;D7PuZWr_P#m&p}Yy;Yx{7}>K-b$p6qNmlHS-GsX-kuRtE4}TbubnTxX1kZD$ACMfI$@o#rmwD-cdj?|k+5>r|HU$}1)r zM?y&Cf1cDO^hm(7DIkN9>COdnayDPW?2TJRa)MWwV0FRtt0L+YCNN2HgN*eb=a0g% z4z7=o4ikU?IO$0quG{ilaJ=M?UZS&fWhbdz(dM!@(^HNx#(s(h(e>#~n^2LZB~C#c z@r-7$r`CcwfHraIRql0?PcCtSROgye!nW>OsJ6M#&#FX;^PG^UIO)ex@6gpbuhlG_ zby$;c-1b3`RFRS{De3N*7{KTTX-NTT>6l1KHz+C6ASvCQ(%pjrquIy-W54HqpXWV} z{kdat+;v^&_dGvm9Uya%t|{MK{juu`NMEV5vCcEL@!@p+w=7BTatN6`ugs4G(@1BY z80kE9!s2%LH1V{a5hqt`8A42vuAr}6@ zqXSX{Z7&9vE@W$@CC((Fy6xL|MzzfDL!7_Lr$^{56aOes?|9BuaaDrS@vgB#&eL(@ z+dZ!a`Ud-E1@ipx_uAzGK)Vg8Gc0I6*?S7Sc6-}Z>u+p&HBO}6o(XD>p%2~VCznQh zz++`730Pj%U~8X?>|qAeolt4Pkc{ErO7Azxc^A;7wYJ5}8#(fu0t}&R=sh8d0wbfTaUi@|ym`N+ClCJOFIs%L5-Ed7syxpmHAn6{L3Y_L099_JlhBCv0Jn*b7A& zPlXA89eW00&x6-Q32ZGjP7$KCHeyRoW37o#|I1Y78 z)+z3XN!P2NXr7A~PEM)}!aV98hZ`%SuL4W+lTvSyNeQW5)}g~gsCfXp6eAXvQViSs zcw`{`A69%r7rAdS6AiFOZ&hp|j8U&^mGD?y#U_88Z9}k@|FluqknC#Gtc!Bh$Wfj< z`CfYttEp}JH~F$Esh!T?Yt=?Kt*YtzSagXjX)aTb0~9FCox#8qJLjAq{}ZA6OY>sP zlbRGSDwSdKTiBzegODD08?q_V8R%P67g22?H9hzSF3t_sD`U~hyV3C9Zr}Bhn^;zv zR?WZQJeh}vb)kaf);kW{-_S&+?<;8)sl<0XAE9eA(b)?HpRPX4CEPl?Pj&?N2QO?8 zcW6FZdnhW7O|>Luv34M@u`QqD^>O<{|F!UI;yU#3n{?e_{7zw!Q>Wr@ zE_~SySq_|2@crM>OK_rBzrRwNrGERP0HA5+ug{JZ?!A<%-`gV~+nDlN$4}oFn9N#I zcH6O@*Qh}vG(v1Z0a#2~xahYkga&Jo)Ka+K)ga+9x8-7nneq-Mz5bbiygFm^BA{T_ z2>*?;RjHtcNG!bS>6ccj?^go&zVWO5jI?t*vu7tiG#$K%aUy!~;}vji1eqgnzYz>q zhl8vjtHmZme}V9KjIK7<>a&C@wIzWiB@yzt^+!#?*zoFQHw`ZyBlCsV6Ij}RsV#<@ zunc!4|2>o87_4Zbd0#@qC+;tLCw5cpw}fVYqE4t){aD&lLs*@Y=6!`^{iQ=Ct-olH zzLWb_HBe2pzL!_NFC7ZxXzs+B+f2fu*xQpd7Su>-J{T|i;JC+i*h0e{H_TZ?-mo4ftu=t6y7#JdH zD^v}wX4tAT(z96l2{=^W9|Lkh#t*l?-Dc*PTcLU)u9bbP>q|t;t`DYvs7IRetx)=RHn` zA=c`1${i9Uet)EjCh@>F-Q@AqH<1$2qD1=81WTc8d&c=IrJ)31%-v-b=ZHY z#@c7GBJ4EX`c|x-cyAEem3%JqIVPR=!woFi54OcEDLW}tl5NFPPT{8ScAf^4Y8e@C zE-iiE;YZsQz+?dfJ>L@^+3o#c@W}U%xkKbYFOf)*x`JZ1H5Yh3)gZp!7`$GHt_C=T zw8e_3p6zD+=XI#)+Ze{pF33d7mli`7GjZ~XIoaHGk?!qIk2wgm3g{Vc_e0WY%?GMO zyPiALIA*41kd(1$gK0&RMTE}ZQt83J(hT2J2TaMT0-sCy=^Dd<7_7A~T{n8J1HjmIer71l1G1G1v)xdz9{!4$XswJF`Jq&VlHJSr0Wl9{yoOv+>wQi3Vv} z!-5bP^9p#Reo9y#+JX<6yxe$MFYce}6cgk)!vg<>gKY{KdnX?^TgyPggUGlZ`F|XV zk8yTPj{;q0OOT6wa~6Mr?;rblTp&xr;SCP~v~F9;<3g|M&vVO|Va_yCy(X3nw)p0u zU>*{)=uMsQyy4b|c7A%rj!sh4N zLNMVgb^yT#@OF1+VQfGE)#2VQP^x9di`vEDJ(45EzE0}dnrr$Rz4bg47=l?h1i3JF z(3|m0WeA^KXN*-3Uw^VHGdvGUXy)C1-hFJh(jC{#1cZVxl-3>%eeb5u?AROAv25%U z{h73lF!-oQ=vh=SeayfniLCC&N{5Tzw>~Sj-RXSI_)gY|>%N~Emns8Nlb*lupnz7> z5a(H#KZDRtt|vP%kij+j-uYytC!5jGkwRg(5@qA+#RDx*^6}SU9#D8SY!(z!eBeXQ zk;SHPJ;D?BCzLIg=|E?N&Has_(ev9Bh90H|m$Yc`C;3m6G$l#=v{J8ew>t0_SCS2c zlw(<$$&-?LppdwCb;GQroxWPC0yiBFH##GK$39N1JxSTj8!K`sib{4^C-|d7-l54! z;0g5>v6J%ea0!r*y>y;oUUT7_i8ikg}iGvS;BC(W~M zv>d)`zqyzheJ%j|4Dh=Abtusm`2B3}o#Dh3+%Xe$^=Q=s>?{^2fkdEN|r34i7&{N1gR3(fW6TUf|~=ENZJKIwrYN z|6#qX6KjDd7&U&e3L{IdmBb}`4MT`kbvi_mnm*T0`%(2;5)+6?3$ONzIA8H>#_Hzc z%q8*gPrx*|;un{`xUoHvW_r0L2tCE|3_tXT$HQ|$L@Urx;TEi7ZTDOya|mNYM!=J6 zpC~qV8k_G=w}hI>JcokU0*E>o`jff2Vym9Bn0JKWHSN{avG9x@x46DC>zAC|087|& z(p#@)qRW9z$yYRkMQC1qaV1{b=NgAhC}v?+XG5HrR^^`=)(9?Jr{!e6c@EE&lr1y& zXApN%qFWF|`c?!(G>J$=m28P8Aa#1zWH%V+Xjg6=PsS^UXsj2BxvQkQg_(>`#+<}- z#!o%0m@EIaGUZ4}5HikZIAA5EOL49q5>9UEvH8&pkw%!tD(jXFxlX~++6n3&IllDyKAxa3*q=r(3HU~=Ba!-03 zGikTJ^5pFhR8lV&r0`}kU7$KXy9e!C;K>TsZ4$UN95w%kwf!;hpGvWCU~|v@FE}SE zUOsKpq`t7t@1Mq5hC6J{dTO!PzNK|=a%m=b8?B*|MzJ2>1&#b#*~yL8p9yAi^f_NK z-caYv5#=U6j^|AK1}_jy_+*_Zk}WP6M10+al=$kUiBgoO@_lo?cMf{I3UvAM_1og< z^-*Osm!G=tkQwp+OY4yh8OIDskb+Lwx8~#A!0k`OVZi#(Od>v=k9&^4Q1SR*Kw;dwQJMyW@zcJe?&+FCOH?5RN(t(SyfwZ&=bx3w%>kuB7jOqhJu24)px?(r~c7|ufV zk&glSr%gmdFtys9^FV5}7|(MQ=_xDb&h&av(6a}xBeGAP7Uh6|ZC!TN;|~tYN|YTM z|A!U%v1K|?P7E+&`GhHOoC$+vhjBm|_u08BHa|w|WDVSnwM%qFw>0>5vvcyPO}3o^ znobK-e1fVb8<02&ZnF4a3#4Xp0^j?Ue|e>=K2g73pRgPq>&m5~&XzLYm~woTz!3QA zAx`v@EUPnAfI>dct8vdvXJ$$B$s%tNfrCwwo{GWKK`|_ATw-&Ra z%9kjCYL4dmS!j>R7a!i2X61PEPeZ=q;VYOoA)BrzI#E4{^P{b+>b~a7;RZq$httA& zwPS5IGG~3#;sxHA5zg3Hn;qYQSY9nzwP4NE+t8E;p>h9CW+aD2wpE$s`Dkq}S#mpB z$?+}S^pIW_nC;Bg`)&okE02Zp%RzKn#Eko9kVtovVofJG7g(2S&9Tew zG*U!qmzv2tWaq*vTo_JZX@%{cos^`XX#!Ii7QO4eF> z6J`^WkC=maa-ff?>b5%M()F=ZFbFB0g1GOk8*a_`F0eIwVjG-TA6Hw<>>VdDta+nd zab?}zXuybzIEf5x2YqIpy{r_cQLs4n!ey`zj$<{C~*T=Fi z78)dNc;Hc3>8_h0ETb-tWKX#98Eev&_Z8 zSuFwwZ1V)X8qF~MLL*l(o@$boG)u);bYE$}V?>Z_NB5dYp?e6F@gJ6PjR4w0(vl^( zJ01f^-)D|3pw%BAt5KA0tnEN^1MUnoug6Z-v1AfhLWV)QW_oT5hRG}oskxCRR8ayP zw}CUwJ2i50U2jnf>*>YOKMmI2Tr~LI*|qO2pqa^{+^DK0hl2Cs9Nuqx>lrqAEYXuS z^Fr=W+GPB^er#*evErgT3((hpj>bK3nt2N(VfbZti!|k^mUnp6b~~|U zl5jL&3LUl}A9hjm1UI?!zDahI>zedu2@O(MkmKO2`;KemgB5#RlMGZf1Mal^jT)4zuG9uY<{j?wklt1q^>=E4l|xTc>T|y$7mM8I z3P#t*tqsn_I@kIi?f3_s0=cBOYo(7hJl3h3G%p5pKfH9)5YL8I!9q2kwEtw&QW$35 zxrZ2!D7>c{|Ds|4FOpXozX>kwKcc_lt=p&Ok=Q5 z!ucOoHF#YPH|>U-O@h&hrh+yrci_nGoI57|BbxbZsy6l zt>VBPRkpwVMsFRxjEJqH5yLlltdZn_4>~qLMZf6#D+w+u3}*1%3iVwy`Z#r%foqLv z6{DD>CJQtvC!aeM0r0-Xaqj~QB1-Zng0CPwe8ZrcypXS)p=I#@u<(4m&JG0yA;M1G zp#OCbnd31-rvAV9fhAM>m1ju$ZMJ3JKb=QIAmcaPK!)@{d4`*1LTT%L6;I3_*JpdY zOz`KV?n!!Gwj8j*kluv5mdZ)HJTaamJyG*OtrUBUjk2&etc&q7jzxz)p5Q3=O3M`Q z-zUXg>PdHuU!lZj0z2#YUsm;c>%(wWMf1m5>9RLiIuQ&sEd#o$4QoK5wdeVbXr)_1 zXd8R;u-LP#B}a$!<1YTWB_XBR49Gfi1E<`SakIdir0bUVDv;``u{L|2OTzpw0PkQ5Y zF^?17u5~sKt`(^5B5(=LWc-l(&>|K#2TTerz8fUx#CB4*q#0qY8nct(peXRgP`fb6 z<7zaQkLVqd>E-~(et7xu>t>=4(b+A7o|y9o#idb=9O{xTw2-f3)JTdiagUZU>RuiB z_~iKI=;=tZk?{$xhzR$o^Vb*MoH(2n{DOigiD zzDpTJ$Ru`OQeG5Cy2G4RbsF1eXOG8Rz9+u%#^(f^)xh#Pn`Hadi>mDBX%~o0;Nn$Ku2;TeK1x?!oN&!NEsC^fT^K3|Sd2AcE)*-Vv1Tj{&jL9O0X zqQ5MH>=RaSddc=NFQ=-+Rl-gq4$1%d2!dl(pTtBc(LWD({`w*s(+GsStMU-w$uK5_ zUUf)GA8md5ia^38HCsyehDh>S`rkzKrcE<9v0P!;$BYy@hC^YZrt`LYCCr2cHN@sD zwiVv3nZQWgG8L%4b=2otYGfK4x^5E}ImV_@D6L#r!Sjd^uY*1q8$jsEGmT-eN8WVR z-`FDQ#VEARKg>{QzsI7y8akG__aY%OeT%4oX*6$s%zdArx_G+jRxi22`U55-gAaG{ zgC(8IQ4S<@f4rcVDbn|sub5Kj04!I#b9CBT`_IuaIq?J^mtwtO#llim@GXH89;6af zKGN?uN=aju|G1j}{Yc|nrP)7^6N2MniMJPz)j|`+zhR9`?sgqJ-FrB0;#Z#pVk4Eb zbI*GFZ@rs77k{YF&MlX&moJrKFw7;{myfI5Hy4;XX^0~?i@;1EanmyvJR;)1^xURJ zYM{B>7Bh7Rjn zA>uAP6HH+n#ruqKbwZ9#7IurNPkJ}y*~Qe(`S0o=I zyeMo>xTV3RQcavxPHii%(3|0bbX)OHC(HJJ2Kl4wJDO_qxq7AwuA&t6<+Am~MDw`o z(!lITWw(4tfh|aUFqv;FyhmORfxZiR_?tZ>99YM%!KS#s@80+9Y0wxT7dN`T03*ot z8W(ZPO8DD1m-e=NZ9j!%*g?YNr7#!t_L+A~MgYuk+ez19JMY;+YazEpp>%p4{pe2^ zH-pAzZ%IRA^vr?`-FsCprdFU<37Z1@>o5LCrTEk*=60?F?&BVl6{BfIMeCQdwfDsR z8t#8BM3#doFe9|p159}9zMBKVd;q0LSM!6Gkt(E^%swv3WP)x*kX2@i=h0h+bwLrn z#bX~tfQOR%y|u4CjpDqn^Jigc>#Rx7+r=pmtypD%+G1hEk-6Zf5^(feY<&Xu59}9g zb~rxk=Lo{pq-?v+VB};Uj|Dx-Nf*iZ*Sch7eYO8$X>&>KiG((GThgJleRY*z^|Z*J zn9ydsl4~#74l73;bqz0W{m|?b=LT9YDQmh9O7AY`Fy7E?Q1k{X;!-lR*dHN^)XYsI z5*56%)<{c^PCe3wr{1P!H#q6rPwSu}@2QCeA|-M`C=il=bqHkF96|U;RFy4I`u6+3LXs4|q^~4j(K;yB(ys8TF9NX^#AhdbF4lgZGfZaK!hZ}NHTO^ChtMsi zN$vPYpUKk&1^}z-a=LK=K)8z8o6Ya)2f*rEKV|1{eIBP#^N@a3fid%Y)?2Uq!AJ#? zp&k(2ecgHWXZ;<a9Ia)N7LWcIOO!?ItxQ65eTR13YV_JwjidQ`P zrhb=6_PRSXH*U)6HX_kSe;@LK`2%SdH4`8ii~gotDtL2FKszS27fQu*kSX8=PVy~!A9Jx}F}n**g?mhW}V->$x|R8@RXz>Q|g{ zC1R5QPG6sUKZtxo>kj+2(S7gUqCePnyS`54>D6;m+AMUx&yKd=h`zHED$Crtw$_yS9&&{!HTo`F;&1W2UJ_5n5YG8rr4 z>2w-1zhKbc*ja&4#jj~Z1l>LK*}!K0oU^wJ#XEVtyX}wyM*ZC~3+R#lx~-JCdeT^q zc2l!d0Z~7QEOf>o@K10MhTE5~HlK9z?$d6!?ppnwvv%;sg;ac2#s{rea?;Gi!C!P$ z-HQ})%H>SYQUwM+eS;NM1_SNnc0A^vh&yqM(2(q zaU7p_Vem4F;qh}l{}b&?52wGI2}>;qNGziMt7sV#Ce-?*mU|=Xc-dZXf4S?y->>vq@CCbqTF}aXkLTKPf0bvA(V~XFL`A zjy|_*3w+-=I@a(A<9f?=MPYNiUMil#tH-^b z@q$7Vj#V9ut3HFh`S?HJUAPdTjuT~1aE$Le%O5w*Y$-{x&o#Gl9VfDQe~Y+! z(REsUGx?o;!{FvOZsmyK25hTg?OZP^w%7f~RR2c5W^0+NN~Y9Ol;T$^-EL>yZbDN^ zTw9`?d^obUf{bqlTJtg*WrE0f6`Gm0a}}C!9ukKbc>J9wp@!@O17>+zHTpc%7dS7v z;}ktlFhM+f|6#3OTxUPzp@k56{atwNhTB4AN39!ae1DAP6OTk2&l8{MXeg{cHzZo( z^v2h?x?qZfID=NLuUm7~$+r}vF?5KzqR^Zd22p|4r{jSrfwkw&Bm0t-Okox^H{Swz zK-V|@Cz=4Qb?-K=@Y9GB9Z+}g>!->0FfSibTCk&0tS3o<_WZ>eunS&Y<-L*v?2UWpQ|gi-n*gK1?N?H-N2;)s8SzhtHemK zd=*ujoh?Q(A@xc`_^dKMqemUJ-(HVs>1DgnK+{1IlSrqhVMSaU*J0$TISu}7t9A0O z5|hGzG9L5V*?$iw5^8|!xQCIujGm`}&GxIV{8=*)gEPyvBWm!}ciLhwWDe1M4 zfda>r?*uk=RRYwfPv9A5#GC{~)yZ$>D88~+K)-;a+!)7?HF815e%^$iG|&m%oyS9z z;JV0j`uJ^=#I-&w6u`PDCxJ3WTr){iT;*I*+#hx7$=6M*82XlF;dDz+fueUs#WFf54E6xO2R%rf2#?!IwVlOL}y z6rAHrOB@uFCT2WJ_To$wT2SHEuGRZ8C+*+@=5Lr@gJqWG-r=7^i=fD&0D7n>(?x0BnDk&H8(l;+gxL%170f>E?=`pF*)VKPI_<`P*)*Rv{KA z7Tz!6%=%cvz(uyw<9=QgDMcd{M`W?dJ>DKkzw=pQiV%B?2{#XG{>9pLJ&$QwU zbws?RbVAFd6jQC~RDxlq?d^trEop;OAx=0QJuiLqE%;JVR#2uj}dT`P?X!u)Y3$ zUw`bb5f3iw69V07-xyd)beFx^Tdi4)=st#d@iMDJmys697oL-M#wu7r=a(!VtGNI6 z<99q$DWZHK3r18y|0W{K_eLM--f%BCccmKB<6;3xfvZf`i%+-F)n=2l%{I8hjgFf_ z&PRXnGL&2Kv?%^@(hWp>1G9e`OS)gowmswG$ZR!~@^vK~ke5LkFv(8r^2*~sru7n> z?O8~utN1W7s%>F>Z;fiU4}wrTCPbh74- zh(-L?E4_XLy@EDxP~~if`1gA{3}FT$9r6NoW`tTT{jk0ICayl`xiGqJLoIrdB+uw% zne{)crVNJnL48gIW^FGEKjka1x*rC$Ek*)y7Cy!5ZO~Flk34@DqBtvm#mRtEXZjEC znfgHAs@m$m#blFa*78?IX)`R%gGLYRWZcqKQ-r4Iqh##r{XYrBPnvj3RCr9x z?vkWkGtaq=_Wzpj-{$d}$S7Ui01-GpgRQqOtn~JTvawe)^WDoRbP2IP1fzyn&(Vj( zIi~$B_68$6=@W90!46k{byLa$2KBMHFYg`=!p=Uvxrl|$qLCkN{7Y7diN95UBhZl= zazB`9xqmQz7fVBLc}_{4?Qfty6g@NEC_!^y=*73nQI$(6Wk$~3Xs<1htWxViMPns} zUvhi(XI$~^bN;$SJ)9O3>Y=XS`LIBJcb1i~92iz;(_)D!>N!DOV9( z^o!&?H~^G5GBkC&yraIT?|`+; zfHN(mp|$!t*oMb`gCKzt%K1+h2eXQCwXX)T)#LM ze|)C(D`&vN50l$4$?c;Z{*q<4Fkg*ANn9P_7J8)poUu0fgBLg8?Sd4;l){2+D0xakZ+K+e zoaw3RzXQCwW%lCkkFkts9q=Dve8b>3v7F{|PPsvPwHMUoP?)3`Z6L$A@FP8ZkK!i6o@ZT3@_$EhZ-Hm6Tt#H-{8@F8c+FtHP%j#mH^ii}hG4Jv&X6wpcSB_Ki7L>(DP!b% zz?>te=6vs1OzdAWh^Ac!V;|K(1L}YP3BaLoib+yI*XI_3`>yOD)a~z1mQ(*%4}gsS z&0GR)nGn)={?-w4=h;>rj2lUuOY-@Dvx5R}nvOZ4wpL2xW#^z{P9O98cFT*+JZp1q zV+KRR7Y^Q!mZ*nt=bkdBe2!w_=Hf_F7y=1&yl2t*m5c9DkY8y{F2OjIKhoUB-fEJSaFk%5ln+qq{#3?@B$A^2kv28IAl=vVE@T(3w86I@BmxjZU`(UbWV;woZuO(C%U!vgKhIZM;y zZGDBtUm2SeC1y9`9Y?3(jRqfq$xSUdGw%0>q@Zq2Oa^oLZm)yBUYmEw%pB^rcnqqQ zF<@=CoV@WQr+R{4Y!ZFex3#_gk6WE@g8Vvp2JbIIcQ8Ha{K8jPo9O3G-d|;S2JeVa{CZfTc{P`g0`aDNBbY8=MGFWAjJ$B{#uk^q?ECBLAt13vUi>?%%(x|G*lm%e@dU)LJV=zCEHnhm)`M%Ma`pWEOuXk&JA-~a z9U0}dEIJNRZmCayEdNH0GTDYNeZ}9st;^FqZ<{3Sq-1lX)FDEPhSEAN!{gK-T*Vr-V?8i)wg@9)Uj&Q*v*Z;Y;IDLlky!EbQl87r(J-n!UYm- zPD6%nP57FuK-@emeiw$r9A7{#1Y-*KPv|Fp5WY8VK?PKQx`A}g&fHC$`ieNQ7;+;c z=y|_ym^^oHRwjZwS&ss0U1kPVFePv%=CAFKP6<0P;?AE;CcO~ZRdtdssU`Y38JKcp zw)E`b)tG6AV!tBGKfKW43RTte(!0 z@d>glEbzxzPz0iSoV|h`dOj_**8W2A?ho(uu0#-KUCB`eq?+bzN4kU{n%b(Hd ztg}+Xah&li7Zad2NR3s6;8@%vjJZ#B({-WQ$C+bGh5>B^o| z{hub-Tt^@&_Miv=yMX_cZb1Zs7FiHK&i%XdX+wUR?`I$tOh1&5iROh*fZ~C(5j@_Q zgFS`SL*l~ZUMU#ZU($VrRJnHiAQ{KUL(>a-hhYKrk z5}kG~J{S9P1t>e=qx^GFrpDE$V_<=yae6*LfMi+uUjgz60-)m-uMY zu|4YzJ?};LL4;^kq?qVrtnu%+MNg#-8>y)5HfUV*zLDMnEtF*JGgGW7vC@k-(}$(I3|JA69v%6LxHb;m_#-jaVl|GWA|Z zWLhiO8{+e2A#xn~OxHqY7)j&BtA;qQz^63GYETxhPqk2N|8p;Tyg1X`j$7o zZncMt-#W#O@BZIbjj+%f`HaSFyh- zdj)Rd4H?!mtJ_lI#Z!ud*pNI!RdgFG^;;}M4CZCuGR&_=#As!{@)>NK8?NN;HFO%( zD#L9=C_cKlgFoMuiX^v_%-&?~S0X2x_F`azU47@||Dd$Yv2*)n-qovZ9z(n}`CU&| zWb2rT5E=2pmu74C%6FTc=0T^VWs>HPS^+gbBw*cj-QB0IHGv(m3T=oB59^|ke>r$$ zZSBILDnciegr{--%Z48kS;oegmV&(R4W?9$BN<{OZHnLMGp=`Fu%=e4tH6nmR%v3* zeg7p_6yy7 z>!WC%E)@9Lzi22}znxiqV<7&D(7>`YHWu7HsCQW%({CAE6zWTF{=(NkUUqDyXHcg} z|6fXDoJ0NtVG@$+sk=OxeqEp$uiDF7(=xsAcSSbGYxK$94$DXOo?$4aO5d@q>#~nX z)n0^S@j{0D~tn*hT zFgc+<)}E5zV8;7zfJwjrFzErsEJ()}1whyB-?YzeXWGy5xrosioPJV}H35xcki|1QKU0-h7}(eYn4H+ftq!GFS>Vs1Q+i$n-`}W|f%)N49BDJFWL_a?XJVd*!U=u2WZH1kVuT(;P>yQtgRX zrHWdmMOV^M1FE_FdHk#^Dh4qzF~*zTCrlZEN{FLwMwuoa80hgqX8G}dSSCJ>#jo)T zVg(sjyK?*>-vI}qz_-k}!^2?eIWcv{k1f}4z2I+1wCm+*l+5FjEPnlmMTNl*BtR7) zsdSN;oO>PLw~7f`KW&|;pH_rH_;6}JdgT9Hq9I%%vY}d>Z_`J`mqDMlbl6!8>E`L%wc%kT(PD(`ip+QzM&4rL%T6SpzvrK@Wita55eFp$tr<<+uz~uO6%&vClrR^ zNl90eTNWVRA+cUDdOO<1<5}p(yZK@kFQ*NcV64` zIRjJ!vxgEi;NFsbNcfek<^2y4)c6zyb9+qx>|yw5p#*~BQ=iCP^*USD)Qy6DR<1d6 z4s%gLu7*DF6xK>44JDuYa2UE$Lt!PqQ!FtOskT9HYx7+d*_g!O z>Ys8nT$Rz=^{kWG_i&^^fG7~O;1ytGcW%=@EBdn}j>$3Ibg&Sw2cN-BskHEdcT?zL zo~{X9Vy+WYAJPWI{o*kr0;R8W=p&~q{)SL zju84_rd2fSTrdOXU+`n3MuKO6e}?cn?{UOZL~Q#v1<{0bqcriIr$zC;eSYF0zt_;l z^9^xt^%k=3jrYI;YIc-NpML&mLo^d#NPLXsxQgz-N;tXL$Ahd=9QKwreEYVaeY^Y9DYn?4 z`5fhcuwE%r(ooef7aC1KKoOn!sa9@ABQ?u zB;GJ|u>|S_@Fn`4Wq`v>&yLg#+j*VTM=$ath0yIktn(@BDxr{+HG8>(Tv6nFT1nJP zq(EPghEZYe+C%G#QK!+3ml`7zXEH%IvsGuX@0qm}c+N#t$;bQTxBsSo?Y%b%ap+jh zVLUBq7;C*%vdZZlyKRMYm#{^YxeUTJ8eJLC=k4Q7JNnVAd!m8EXxy_a2+%CtJV7HLQ$bPWk{9v`h7pxe2Z;m=?hsne*D+i#}6SGI^SGP8g-bxfzE?-+)hiKepNW9U<& zf~)l$|&GkL=oGET;Ew(hZXl54 zsVK(}L5&h8RK;lRHqa9^kCT&o;eqszW_3H_ClV6!aVqfIc-Cq*2rl%K@^ovZ(%0(e=q$pf8U^p5d_Oudnmv$(udM z3v+-JUOQ$b;!I)5&u+K-72AkB9tZF}Jyfo4*~JL-=0zI!&|)fR?t}i>`WUU&Tpd2qb!1 z^JcpGCQ)2@X!6bNXO*tMfQtGOM8AY^P5cIf8UL4gHj-wx)(FWSYS6$C<=J7P>4&2E zKhtARMxS`T>q(pMJTRaagq}-^p*Ep;%T8NeQSE$To#xPFCW0_ z?#V$BkAo#BG<&-!Qu$Kq;)5dL7*^NVt>N@kYA!s(fOqf8r0$4C%NwXxr@5j0XL=}H zxzw1EJhVHS><6c*92;rzGT;^J{um0_Jm4#TS!X*T^}PKxpA#~qCWO%8{+Cj$6;d{L z#6;?;rXY(}|B#{EanlGJBI~3kh7JiE_-yXHtTW{>RtB3JwYae6;>Pa|RjfhPLhinh zbC_K1gOffxyPc}atEu|rnO*UIbn>(Hl4?hRlLaw{_6D7H{;h*4uOp*vM9qMeisYZ{wF>7tvQI|$*`MV@1>p0;j7e#A}1BCDHdjQ;$4bD5vUW4a}?e<~|U zatIg%n@p+&f8H!3eQ;ZPCrGNBwbIpLV0zGo zKqOftJK*lBO1iS5j6rm^V^;j+r0C01e<|lKYpmz$Oo{R?|C5j1#NEUp*;l2Nh2)lh zc}Q>G#6Oe1ygoCWn5H2ETw=#E02O824O|}|&ZBZ1PhcwK0r$~|Gv`4DEC2HK`~7&P zJS^u(()hF}YZt$k;Xka&tMWzZy6(up9KFD*fB;gIo=3}PiE##1&y7AT_8d2V&spSU zsVr(lGVw7!i~kerQcT*|kvtzuqhBda$id=4-o-+B0v%u7$uO@nmLx0N0&-Xt>3vQ%urUku?BnVmy* zSs*~uh5j)A@DPfAMhz1v`{oWDADPyhNRLqpW_EuPh(AIA#X*;JPk>@eJK!kkVSsGg z`(QlH_XccICNJpSzsdPsR_BF%k;(F@6-tc+I``1t<&&&U*Cq$63vRouBHo;%oV76A!z;zAB z%2_--dFkt_A8(Fxo$*J?21|IoVV%-!d}c7HA)|ME)^fm77d{k`?oHS=JK(2BMacBqnW$lVfA{RKdY?GB_M(Yu! zT_C^AMXoQD+%!;u0#Q@Dn9)GkX=95fow*$M8E`X0Wo}DZI{lvM6Tuu?+Ppd%hE*Hi zIrX+y72-}YyaRupb!|%9ebW70SdhY*W;LRqV zWpJOYzN{{67eo98NS=sp-EFRgL*KVz{gurwShjC+V+OGp8<3X;k{u*g1%)^Mv8X?- z8LOgsq`t1Yjjcg4dP0ukx@u~kABN-TLt}TRGo>hegc9Ae69gVo`T1mdfSr~xDtNjy z&dW*8d&3r}+{;o+tHFaDWajlzt{j;YBE_S&#+Jo8Y*Y>iKEG?zSC)yI^b84ITFgbi zOoo9Ux-d_U78k8_fx7YJcl^(f`J0}Dv7RPa2-@fV?w}VfEY4EmQ@5J2KYW;P7TkL! zIW$^mRoC$Jg-Q6w&qW%_6PuqGEu4!xMip7rc4FaxIUiJMT#p#-Z#hTdF!?lDz@zVf zZ0A4r?9?2wb z3310KIBO2c^eGhOg7yt|J)iz2dJOA#-Yi=yaHy|>c?Fm0uPPQc7~`TWI1{wE^atN) z{XHZKYg!pX2HiUXbpZ7O$Arknd#L>hIdoZcU2#+RU?of{c~ZZDEv&U6151_MdN~ys z6+bE%cDk2*XazxY^8S{fDMX+-lRcmno*t)w&#LMLPU7LZTGRaByPF$R6A}>;ETfXt zmcFdF5rB?e0qLOuNbV#k_L6?_gdu6Q&X=g5GBaUc>+Kgm9eZbvy`F1G-SZ7ISb2ti z)6Vg$P8)<;W|h5lSvM@nxZ3ll5PD{Qe1o1$eDyo}{8`QK=vP=pO1?|7F8U$E%$K66 zTKLmU*dXYyuICu0qOQnbjuFsz(bCt-E@nM#;r^B_KnmfPJ$AgJuh_L=FC=c`5vsCS zsT>8eh(45Jeb|aOZFf%Tt9UZl$3SAbW4P-gGG$N_*6y%f{m6;1+V`z_%k1fF2w!`+ zhe)>Ie&5oL;h5bTPv}sIwEPax_<7hhJn;o4R(5y(&5zlEjQ-?l-D^Be8aOCn_sqDH z4i$pr>=i$E2!4+2Y$`UX9|Twa5d(ir2%ah$R@BmX6bqBoUIE_50-U3~t;(!Z@*>9V zw%+BAD}f($@bJ%A>8s0ZIUlAiKKSBqS4RMBERK#XGyjRZ8S{77i$pztvq?%%pOG{K zdqs`?&5j{Gp>qpje=&2vxa0Q&{g*J|_kmT}V4=kSFm;x3O~3EkpVHDuw*t~2ARQBs zR*>%Q?vBwY-AG7C=R~@j(cKLLY1jZ^z~KAe?|yVY;OW?F`|P^Tcpt}ckg8$e|6%rq zqK|j2YJOl%_gkj``~o^?!ZO!%E@G%$uHnDNvUYUzOP{nNg;_jXk4W8(z@y701x9o0 z^HdT-ePU34cDmb!6bIY4M}9?>9{T*c2R7GVmE6B=8hoh%STcXIebI0&IGOhl!D(nA zOE10l-tC3R?UTb@|1_On#4aeICdhNQjR{pFsb*{uyM=#pQ_q09r({@JbQ1Nx4o6|5 zI&u(TM{(9+mE_^d0_I5)b{RW<;f8CAt@5As33XIe#)^QctLp_ zY{Q;<1`Va8#jhHfSg@#uIY=^orgN7K}zth>9B6`^2nR6SC5JrmN|;diR^#Hfu2lXI3e}L*C?ssAV;_o08TIY6yH4# zn}ZJbV0W_`AQo%tETYaoh!gUy&=p3$A7`^QJ{nS_lrfYrd!XDR9kk%Dv=j*>f~Qj# zNEpO(Xlk7;x1NZq9&f>u)S`^3txW`U#yzuA;oZd6D=}x27YL`_yhGG$N^V`W;n;C zY)xedui3$>3-8rN!s?laqJ}_zVJwQnR5)kdp_2sLz;t7aez%0}cf_H7*!X|st5c`w zgsMUbROJC@oM?N(g9W6>Pb!u*h6W`*VUa4%IL00A`D(u@kN8aae%SCQ60TUVE9b2x zq!&A``_>&jWnttV_}8*a(AjAztzL~P@#^H(A2ASjGXychCMR3gxSPZ=y1zC0k`GLO zSCvQO4C)Px=oK_YH31AjyM*_Bu+!b7L8VJBu85)X-}AM9rIbk&_=hJ*Uy`z0kfOr3 zU$(e*`lAl${HSv29R>=!D7~l51Z(2Lz0ahK*#T zX}v#c9+aoUR^Y}nZ?68s*T)J7rG$ARgLf`h#D{bKmhU}*pu(|l(zJUrFWen1**>Kz zQH>a_l3?z+rv`dBj!C`U@`FmlZR$(6Pbr>_=H7eT=Q`EDPKlCmW||j)fXNn)+nuB& zpYs0Ek&A@2&5d_bo(L6uzs-$1LynhCXskQcG5^1eRjP!eac8`VnqHP_X& z*xUjX81`ctz!mq;lryQ*75<``?xq|?smhm@>JeA;;lST8TfZ?ikja7}x(B9aqnOE% zn9s$?voMqw(+xa?oJge~hfsq0q=XuTAq-@7{kX*TjEA%g2>8P?+U}EpD%`ZUtg~gl z_9>r(JBd3WTtG4Z8IF>X{|ES8%VUZ0 zVU^jebfiCgGNLI-`qU6e+DTPeSKw=1hJ4!;%Qnh);v>wa~~kO-}F567lDRi*mJ zzI&)4s!Dv-p|4pBdXg*z;xzD>xcKG zxnkROi!C&M5&wK8)#q4h(kZ?8OGyI4A^(@9ZR+DNmhv)iSE5H?NPXh!dp~8+C*#8O z<%XXk^V&@IjxPQ}HEYdNqoxJ@T-V>=oKyo6vQ(WuCi5BkV|iUo9E+S{uT=}QRdJM* z`7*yZ`OH0GqkP?{kGIJ~_>VvY|HSW|&i|3;o@)#jm0Uq%mA-w>vx?-vLFbWPb8mbC z)*&tib_)gq)xrw?EGN3sKMsgZ4pI*nRjl>*&ywrM*N@GAs+zF&Y#fex`)=IVbsU6X zB|N`}zX^G_dfd&nVAci~`cdm%bh50o+r3VSVH57g*ZL#7+dO)WCb>^9qzdI-b}*gP zV_Sd`BoF_aRX*Sk_KS8F+h3FUMm-H^6W`AWTbnm{hjeQm&6xaRL} z9%{UxM0e|1e%l;m3&qDB>c*KveqAg1dYbvUUJY9<=Oe+V2wKa~keaQG`%NcZ9-*33 zm7LQ1yj*7#SU}=@vYiVo`YpnRsUICph@<@BSW}rVk|f@e?3a415_3CqO}su=kiw1{ z6(#pf?mA6gH8v{H4v(O?J>`LV6Mg}jmb;wo6XUcnwDr z`&LNRDOeVYh#Bn0+%@{(UcCW=&IdvP?(_!tbD%)3pjWm7K^V`-~ z?xs05@JYeX<5Q_nsWknq&Yb;1e$ZxPh!)mLL!ij@;-53`ttCSHQ3}bofOLOUHF($I z`BX$J^xp2FruxSUN0od@uZhg(#NO1G3aw_*lfbptP?rQ4cxNgCMz~WD?}e|k7K~{O zG+k({(D)owQFEfc#x-LZ*?yu0N)I zU#XhZ?C#8hxj=0I4lucf``)G@{HucA&36+ZF=qugu~b@GQ!b_E11vYuVm4i`#+j58 z=lRUeNS)B|2*;tZ@&nlfGO)21z;Dw`sNs#H7(3%P411ifp%-VK!>Bl$-!|CtRQ1?dUbQjE=>}8ZJuCXbpT2`JwS13C z(S@{1jgr3`rJ_>?V_&+q|E^n~4>T6|c@!QlpSWK>EGL94{Re3BiF=elyti+5b@kp0 zMCcvsygC@POlJK7a--d$Hwwj~Vmd|z+LqtCB8nbOf;515QY6Q4wgz_I??{C}G{NWy zmrdp@%IkOD(-}@B5;18_Xd(xwA;jVKitsMzZ2VCWNwC~ai143Vms=8aW-wghA^x&e z?(R-X@pA9vBmb3vEFap%uU#UHubjtzplD9r@_#w)at>LZrxmg}j=7fb2Giw#J>-sA zQg7kp<=p2_XOrBE6PKAO(;8|w=X2dfsLm!B;SAbkq?nuJs9o$#iTRAyJ0*g(d^PNN zz)WRsYVq86n?^Y-Bc;Z zeVS+Gemd-U+_Q2zn2d1Tdm1Wj2qUj8nTYG^b{;upe2C{HypFwh+?6~i`)N%YFbG>E zDvZwYPEg#Oo*2YSM0MY{1rhm-b2T9>cI{gQwT2spaSKsfbKWVslM{X{N>nV*ow_j`N(zg1fOsb%ir zhm}*jM`y)M|D{?0gUVWa>Lb}*<58U>)8T>6*kzG{vFtgC*kEkrPAH`3ep0P$b=sZ3 z8hvGOKjSpwf9f;f*N#Pl!->#(HxB~Q(^J%N3~VSSRJr%A2@l2}`cdc8uv+C4=loAB zN;{8X2XFjQGh)$bQgN6Y&WOanrvM$FR&Q)YUKe9Gk)vR#jQSG2jcJz_3T4lyEw^Ik zzo~Y6iMz$E`&OYem8^xL=XxBm{0&RPc`XWp+{gMa2kfJ!&|Xgk3nbbrwpx)e5np?u zp;Xl0$Up{vn6sUCJR+|fP?4CDd`Sj}`|VPJ1|w&M5lz~yF2K*I*73Y_T;#}_@~TT8 zOl_q4(c)DURu&^XBC#A5c}ekFfeJgBWIVod{|*hdm%zOO9t~S|r}bQiRs1_DPuz@r zksUTFazfx=spHKhTFVnNRP$+~eoz|VZSwK9HdsJci4@(3a>{>7;lL=PE-4t*1sY%5 z;>mj7(p4AEeON6_u#oLfX}$7K4$NvENTWw(e8t3|1@=K4&GX-222hJnuSBgX~omQFJ~CchdF0+SN;(c>b#Z558&3XHcz~Bl+<0YKqj=#jt0J zrnLTvvA*F9v?+9Wry|R~xA*e&8TivPOZTC&)od`O)k&L5vKKvnBAV1^}z|J|OR zRl+7FD2u_M+WFV75tgjvgtR1+P0;j#$);$Z?8tRl$yR)|qc3-j^1PzJ+i=}xe<<>< z(6Ob!gNQH(@Jb7l)*A@baL}1Jk22#Ib2-iDlwvDw*c4{vzH6Z+KylY#Rlc}|fHWnH z&70O79gV}bLnQFY+rXLngcor~GNyGHG0${5SMaXRIwhO=%R1V7&Jnt#WikuG*M^@Y zI>SroKVDwm|Eo4SEuyVY_nz5B`guEY+O`_{TXJ6WSB#i1KWG8bT?BDgk;?jcuXi>- z0|N5;RYNl7{9?3Q+E_>wH;9VyQOI)sqHurGSxW~TJ$#*}ld00Z<4uo;hk)TyRVJoA zq-d;S;kHVt;D<8cE9Fku$h+87Z|$Fp{G_72H&nlR-g^s5rO3+O6+Gj_%Ayi1Ug5ca z_m;3i`GsSomPPJLDth z;W4k*Wp2=*Q|U_GnQh*yq;V;mVeV+GVsGFf*;_6T{6g{#4^_S#I?F?zKCSDp*iCV& z4$=#Nbh;CK-UT%6Q0djLfn)LV-;kY>%L$$Q(&)W*qR5j&$xO!huaqn3fuy|Sue>xY z@*11%Pt@E)A9u*`jX<$pz6<+aUhGu+Y^st1#u=ACLLnRb`<{At>y)HFTrm9kd4L%K zeg=sIM+emu%>J8~(RH7Bsq@f9321hICj;ogR-vz5(9GPg zt{;--+B(ju|2oe>-h9qWYe9|nSY=kUZY>MA2-2k6 zglk?AkID!X7$E1L%Dk|WZWG|8l^e68wMybq5=45x(V7)qKo9>O3@eY{eR@mI5%AJ| zJP>u5hXNfqML%q>xejfwNq;gt+sKq5#cYumfug5+Lco%z!_)Ttex%EuEohuNAu+tL zw2uu3wG=TY|?*jC#?fv_qZ zBX8_vr7!V-VV%J%q-kER!=(R@>cmsw+|5i$MLaCznD$w_$=2$?je?05fZEy=9O!&F z7ZPsuCwV`#&&mj2xQ3`3-iif-iGK{do=q}0NeD8iA9&@or43ae6M>z) zb0?Y->0g$nmB#Fi*vNUw82?>_b|-(a7w+k_^&wTX@@6P0^1GkY>v)t+rBV41Fk+h^ zL>MWt7(yPT2a_t;1LB`%maXf|dK*e}QK%Q}TDz)Zn}?TUD8B7tud&#e`d@G`FhMhQ zlI#aMGvBD_&a0(N4i$6d)BuJ_QzW_YGrkKV`%y}z>UpLBrsgd5A7JPiJ6>|BeCOq) zRQ)f7(RbSja3Ouf-I!Pmpesbo*uZFIL$dUsV~EQ|M{!?+N{^OE`pEHdwtfB=i3{$X zAtTuG)HJsTrjvva=DpQLl@s3~6-WbwD_zS#QtUO8a*$H#D&g-_P0KKl(~oib7x@y3 zV!zWUYpXPh;&c(#+r|IRQNTKZEsplDnt!IT{{aNzjr6&~v*Xzb?tE^}&BHqs|nnQlNP5w<4+5^*>4%^GBho5b! zs)wRT)%DWgDv<$UnQiq7pMM3>f!F?QfrnOl@%>!h(% zJmGICOi!W^ysXy zK-^aQ?nSYM2wo>CBghZUG&Q4i*Ok(*x=*aiF`o!dc9HGRu+diEW|tRetvGUfY|4`D z=7UDM0lQoszclq#%i6FNQ%`wVuQu4sK0j$P;K8D-EgBem$&1aEZvQOd@&9mwGEn+7#vIZLlIGN&dkX?E7%tOp@#Hq&&R`{fz# zgamu!oVe`$>T6#6^9?!yLzUbgY`^F3>+AT2XW|MB$aGQZw`xV&V%g+$TIKiz)P4vk z@I}6SYU>&k{M#1b>GWcQQVFlNs=M#oyph9K#uvf{Es_lalm9e;#P#x_Y@c z9Vs{uquG1cRqgN~`KQy(^iw5`a$k-}?nF34Pkok~6FXyGSj07+L3gh2%TZ?^=8>uS zQHFENQaeZJ61CLgH04RT>Y;P1rGyaiJU>1Kp5f(55mpK=@N0FAF0rpAztg-* z)@~auzbI|-=b3Ur+nvQfJc)r;#&1)ZOHA<=*flUGEha3lbBep4QczJJNjTnU+UYIIN)pVMYndknIHZ^>4skS42TTO51Rcd!FrcyYJuj zqzGWse}KWun0yXn-@*2+>ODi9rc|7pNnQQy^lE486sowYkyqf!rfwbUUhb6X$fT$y zxO%z1#j(8Vl#5qM@n0yl#3WCp<4Wh4siFjdsU~rBrm=~;t`<~<80QQ4Pq~B|m*J|Y zUH${54g*z?n)$2a5JnYl7ZE0L%u3h4ipB8>8}z_|F5v$LigiI4bz3HKC2}m$CAh}+ zm2cMKrZ-q;NJJ$LuWnKv4T6Y$zXC16xa0?FlTJPbYKWUMhBfF;C+_d(;Sh4|8$#y= z(E3k%6}DZ(7{5M}foHpQ5Omeh<$m`)Zxn@MXY|0TF;G{2B2X4l4l5Y&5o@HUooy<5 zSBX2gL?Mo6r@8|$3ji2ypQbth!=I&4xouwz2l81L*ZSA)J@I>WH%VL9+!Pq0==7qd z0cW|lU5GcdU{)kwdDQe|1V^rA$A19k=N=P#!`p5IcL&q0AgqJ@g1H%Gk%RJEoZ_rC z7qQollFQdHYH7LPqx1a_kY!y5guhcgjXbqS?R))rpN%{XTzh*+g6sqDC0;;%;O~mc zdh2G|CG3b=u&t8MvG*@i-7w_|Djyx}g-r05ywoC5c;Sa0oUN0k9$RO%&z+zArYC-V`8`0zRW+d= z7gU4V9ubPeJa60KEd&3SKA-s{&vouTi8 zxju9X_V2OFl<}9GD5>N$GO^7?Htlz^s;!ONUF-jwf%DO-K%acSO}x*Rr%s^uC@YGt z7L<_Fd(5A$T7cscV%6M>F3Wjcf(xG^?L9Hx(`!kC)Zl31s9H!=so6Bqmnj}@+4uC3 z%ZG>OUmh6{rnd=--W}XpcHM89TKM`ZB&?yW@wDIms$bJuox6`5GR9k8j5Dkk_nEw= zpiV-eig7Y?-c(Lol&bHx&HhM5`=z#YwcX@m69sYw@lh#z#&ZxG_J*ea6+mkO(@=eY224;#v>2M$S2*FjzI2i%7V@nwm>Rmib=3x-EJe! zF+M0!Q}rz+Y`DRHx7HM@5Np=?r#*G);*Hbf!C1WZu+B*30NMyM{UWrTcNi4Arjt#c zWee9GBwxCn<q2eL4xDS8qiOF(;C-j%6hK@t<%1M;d%+*r*Ije3+wM%qG%%AsYSVzqtldi zX6a09gU)x9SCB-J^O3uX3mZk8>s;H(0=4JgD3e@xLkRz#3nN&6uU=rk;^Me+dbO_k ztB1q*FW-yc7#u(KURsLC0nPEN2a|O|8u`vj4OG}sUa7VWzlVKClQUhaigYTj?H@;( zLhAvSjgRQ?`Sk})o;QX+mOgw2P+=}IhKn8+fKT;(0~WYF8;N8+LTU3w(Ap#wphnGL zluUwJ;oDgV!_ks(K092s+$-6ho&?s_*15n_9u%cc9XydQWkJ~Fw)5vqhu?wJVm&x% z){^Q&0Qry6cekl#O!3woR{u^N#_Bb1#V$Cde-)e)vj8CJ*=u z0Jxi$TW#R|p4EmoyO#KxN}^s_!7CSPH{Qhl$o?V!osNl z-Ec(Oz(aS`USNW_)*nAx!*D#+Na864tQ#kQb5kOq$NQ zW;frY55_NLVIO=L-|ZPFhya&Eht^uY^IChoR`LwUf<I>hT4U#?;!Elu@m0>MHS(dPKKXP7=J&Qi)4}H4RPz)V3MA^( zhGXr<#02R=331ADD^4CHo2JUk*ZAztFX&#Rkk*fW+&`ZRF|x{6y21VVaw#y(ldhim zMRG~XX}I`J4dKcIL~9Xb`#(t}B3j+s+trb!+BjKwgPCNML?YD7NAAmJoE1Dzw@RUG zfapH z@sD)_Y3s`mI?3_noGI2F+vEa70N zUW>-pYb}5fzT<9GR$8$!NZD}lYQW_{MAGkY{b9K{rd&L4(Lwf?*iZ7SEgHFvh`Xc= zze&$IM{)IKY`9 z$ppl3J>(@prgNVXTZ=smDYrja)?iZ{Wh5aGWwew)ed(~;Tz`oS4xw;KF)JUe`FmZ~ zT{V)D%-F=@12d&V@aGUeWR%0n*4z5TnJsbiugD2u5*(2!4EM#zUz<}IY!%I$iQjo& zyaf_<_HEGjhO+q8ZgE4EQHgk>Ty=0TTvfd#kB`l&cF`bz$}xMCebNFWUv4?91SJSAq$*{@HJ z`Y5f<+GHn{E4W`tV!C!nt*zZv z;!i0*#zTFR(0`-fu`u4x}Wmi8^ z)o8LPO|}XSa{||F2^1{^dRTuSCXLs6w}bkf4K*Up1AJq>lzs(~*_E_1(fLnE>@zHt zwIT$-=)6KLWkL#o51Ufm+JQF$pD_ZIy{gz*8w|i*^g-2{rL0urnv;T@Y_%-cGHg6( zitdOVLh$|HKTY8&h9xOxe3UPAD=(jePiOvtBBb85B^RHreN~S)uj3cFd_-ZbYmhz= z$x~4#exQF%coBeS6#Mhg^J*b?r{%kKsAB$`Ijz^B+T~LUJsGh{{Zb5z5}b8@F)D3s z(Y|TDYw<<3E${yWXq~*+sGicgg}iO9jqlek;&97xS+4&nR-a*Q6JquAUOYzf%B~o% zhw~-&2|n5V_{KxVwB=p(Gfol}!|(ENTD-r`YN){jT%ZH&?bpKagU@8+xnQbO_!;M$ z`y-AYtHj?gYT&0Rp3|-7HyW1chdqMVUwsscSbWA@aR~G&L`2*pWIr7VrKBn_qCucK zLCwTb@!$V0t>1;^K) znuj)0AY1Dzfx7WQ2JoTtykFuZrAk0Zxafm!SLWBsZ$DFpOq8h`O7r)Pt@?OXmGBd{ zpTSF2YLyjcV$Y6j70~0 zb(>GEwZAvzYCknE-#>DWx!*fF77s-$cfGAoaQ)N(7Y`S0Y>Mpr zW7|G77Qdf-tkC6s9K^`>L5rRFmd0nMS8Sm^A4A4IawS6RkB=3f}bl`{Vv>U-LuoVI)dO>yq^thY}YG9;}z6@>?BH!_cN`@?+z!ZS-#T(aFMG?A3g7T#f72Kolsj9&OE`Thr#z&N=g}W)S zbs+uJq2$gJllH8H09ox;!Tu!Ocw>s|Mv;NM3--x-1>mwEw5JKDt!19aTTjI~@_Tka zrUnI`^cnV0Fa{_nMW1qMbyXj8vd{|3nlyg{rg zF_}7hDn(d>=(~1)zBZ2iT>G5n>=s{rm#L`Thx{hEWH`)}j)_hr0)1GH0`&QQuk)U34v?#Kb%_@8Mw|iR19PFmF$9+ z;>@e8KZKoG3pDWAjLVtm7jplEwgKEPRb^AsPK}_{M5)5I`5~IUWsPhbQIxEICl#?r zhb8S_jD^Ktul-!ZeH*d1u6^sdS`lW`bIg<@)1YaKI;JMjW{7Cjvf?r@a) z1_E?HaeUz9P7*^b!e~(blDIo`3cG52IFbj>ee{~(fXmg2QR$ra(eOD(_+?MwXwA67 zGnh5cUAwf=!zU%IC9UPAuGIoepwGN)`(%G5x=F8Z!ZYzq9KwU7LV#g}koK7V%LG$` zrX*ZMoXU?OQ1NLdp|d;gO9i$aT3WY(a8dNwznqT_Pi?@RJcjydBq>8LYag&lxUWmx z>JYtEM>(7 z(d=yl5V4p2^dhjy=cKBY&*t6pg5tvAdt1_r^ujMvVu}DuF8FS$dB>x}$`hM{EMBO= z;>Z42w!)X8ibr<8LdBC}fQ`Pxz^HloAYQ0S!CR?JUFSpDq%Yu~Juq5tuqzCRBn4?^ zTht>`c7swoY`f=8v+h<%##ypQlqKP{3HY6$EF#U zPbc#;L5_RdcuIed=o#CanHoLvLKI@}YfRC1IZoxB15T`tj|G)(w!0WQR^5xX&@W0u(;%1N! z2Lx(C$$*EY#q#s-utCEMSE}Q==%iO~yJeI%{7?s!Z=HAA-}i`0VC$6D0Z+`IA8)O% zHe0MhClnCR3@KG`GLnwEB{`Z zeYxPd4Ot}|7^g?P5)C+89R~cvYW)X(zX)<<2silBVYm^bHTO1^+~UYpM_wudCl4yw zb(##vLXxuH2F%u!8jBlsEwE_an>EbSX1ad7tapc)hWuNO854N5u90@n_sBk_tN8GW zvw~WkHt$#2t8>PeK;@wJ(i&B>@Z$(}MiRdV01w6j#0lQ?kpHI57JO3Xq7(M>M% z#s@E$W>3AW6nsQo@#v^zm~e@Z@gCUcZOoD_ym`oXsW+yEQOnP7PHqd0l_&U^bt-&= zGhceVFX>3?hD6byK1zg40x^JAgi+0==pg_i%bcMBm|h43Yv30ioFM|M))kRh*YMl`EC;)d`PbQ*SxN{k_V{iAU3;5+NF9;cn9mDFlv|MCvRfzktjBTy0s*EI_5 zIhRGH-`$zRX5;alJeI6C$kBhb1uI?9rX{V;k)5BR{&}J9c>#$9s+=ck2%hpj`>fkk zcwLdR9i`EUrMjHFB+nbA?`Xfr^9t1G(@eQSAPhE1rsU0MRsBX{DNkGP)k2hpv8Ty0 z->EmWm00`h7nE#cm~odewq^>9ya+oyJ)K#ugS;Fl?R@z$lDM&`KQ9QC6`Z`Cw! zIcOgmyCS+?@Z4fL)|EOe-0{b@#fs~D+XC7M8bi5Ssh_6vm!D(D4#8B9WPaU(2z~F} zyDZ)^NB*Eyv^x)u)(6j?0!@=z-UM8|GVIx*inn1REUO%~TtDe;T%{R%21DC^WH>6=Ywy9|0MlWFJb zh>IAZ^n&eHS;uk{XW+6vqjPpIO$<=E_FXytsJEd6k0@56k_gjEd+S-&p%F*DZZkzA+hp;O zvSuf*@i<2vr5}(sX|5Oh!Q#Q3Wu-c)N=LS@RFMNWU8*W?Zv8hn?<3q%kKE>u^@L#nwe}fCLa-A;q z*mFZ74AQQw_BcH>++96;U`yE|^R@4%__IY$Je2CD2OhlVOeQ(fqTEvl;$K{<>~7D=>cBfra?37uMJC?TPrK_APq6Z4xp zQ;u$GyZL=gL8b%le1C@gJ^0mtrt)O-+5nJP+wO7*>xaCOu9t?pgHr(k^8 z`X1K6Wx_Mn0{EFQNOSHVkcY|%#i@DO&+X_Ng^qT(derd0*lo^xZHgU4fI67kk@8Qu z5`<8~RbMD5>|YNR(gG2P3caGb@IKcOd9z>^*!0L;G<}Ju=Cb!uiIvHf3A3|jj{!os zDt;S+XxSlLH+{xo-5Y6=^fZb!1W_RQ=8y{waPpjZtDLzD_T1nEr1g)1j1!@!4ciFR z|LY7gLMT0cM%?WRpD#bMBUtqo4wr8t=Yu|G-mVvD$B$`z5wp|eX=(oRUPp7`bEFjp z>#hKDfS8N)Q07nKTWY&YI_J{P+YYJnmEAwe3&kdJeqs&rb2-Wy>dt}_P0Grr@)_qV zs7i&y_B>|iN%aIo5yXaQfzv1Ui&SVgvAW1(Qj^ToCFBR-B{Y(wmV@A%gP%k1GA_;zj(LDY}3n#hWLE?*ZVlw^R>K2 zhLTI|yKUzzjxSPL>oZT+9GTLoUn`YA6dR=P27Hl9RT?&snl^R~Cj0}k2Cu6Pf*O3! zoIv)#*n9~%Xo!KxFLR6C29{q2U+pU??Yu;bGb{^bj;qgzBl)WXs2u*{6R?dWgADqb zzLI!d?T2;i;cUA{F~mT+p^wUAb|&~W>FHZ)$b+JXV$_>iWL}X!MzztRB~ASYNMdL? zC?85%4OD?S_?4-sML$z6^EiFL7b7K+%*=Iqeb#){J!oG>ETJf{@XkYe)1(-6BQ`&^ zfvgKc#kS*sPD$&kTaPMmp7LMsQeanebL;1(E6?;z@r?miMDhGzI3u~BQeF+<&HKsV zq!%{+`=YScwYm%zcKKJMlHSH*yz`n@`J(++qEf&Y>hqel;#2M&4n|Wm5e-Mb_RpQE zf#z0obv$0{j@73Oe>PNHlDLbK`C?@X|6nqJ>_4WbCEx7M*gANP4i8%Ez!XQ<>X{EbV>{I(c9l>epHd#3O| zY$y031JP3~8V)pNXE{r^TTe+D6CyT#PMkGetIVY$`vv1$)U$5A)Eq`H1^T_0{HxJj_w}@r2m}{Uf{r*8W|4ypS zDdZN^_kPjvmJ}vR_bDfz@M-YkbStG83$PQBiPh!Di+qhv{P@a`0na~8bLT(6it0tz z!5?|e1IhOgFw|6-A@auL4CQo6Lh3%k_^%>8Wa*2b`RVI@Ki9gfw ziBxu#QrRFS&-~TCrHlWC92Fi3Fnm4o-O4qS5<%3WC)-HxNc-g*U0KFD2s-$jcQu-u-2`0 zA(4EM4*||`e8h7|lK~BTd~PuS6_LJAeWWo-M2^K6*5P)(ElAtk{Ow>f3y}JQ?plBN z&no`#xd8RNYQ|qnt{8M%_IZ?tJ~*Lgry{_)IR?(#*IrYnZZ&IzM~a-V{E-S|2Ou3U z;{2eQAMxAJiLmbfr|X3R5uDvW%sXL1e@cJ9Stv>-XvIXqd>Wi_bm|L*(0w0sCJcT3 z0#SQ|(y4NgZQ#K;^Gh#Gkppeo>r|yfI`{Zz?-TMrvfg;A7S`gFF5dmAD9kDB$RA8P z|BtD3ePyU@lr^hmA{pXhgux{7V##BouAqGA+eC8*<%!}_D-{=H=1QGol2m@Z`K z)FvRazQ(GfXTIPJ&Q;L<)@lS#0D|u32mHL8iPWAiFX+F^T&^Garai%ijJWoh28B4T zG#m2ibCZh{rbs`s^KQ_tSSqBvIfhU17nSme!?;l2)6 zaM=3kKc6v65@}U4oepxmM7+~I^L%tU<}D6c05&{F?Kx(BgE-fAVs2|UQ}$kg%{lp@ zfG94_M6Kn0;qx|B9)Sm*e3X@vY#9NILRg5EJp3Z-|Vg%%Iwf(^=K7 zE5f(EPq9Itx5%Cvl$|A&yjcvPpJGaP8+m4`jNn8>acymz%*8AiX?>PVn(lNcj~cZ7%ld z>l3(VXZ~`%Wv6)4v8TZ_P67ybRfyrrY$iz%*1^p$ux<@bB?2Di)%hLI^@xPF)*V?+Y*=oF5^ng8$|(P_MzYcLVhn$A zdR9Xap%ks(-5w_^hRlYH^}y^7 zq6;jY_te(2Zj-$7qU@3Eu)m}pF+dNM$0*MEh=aj_*>>tDrl(Yts-a-JKI`nYrDXY8 z6YTAPOm}P2ST~$*e>$+rnAOX%^1HX}Ceunx>oH2}_ckxbo2A_PIWTxxv^*}>+tD<_dMORoMEk|ooK-tmZeeeC8N52rhJC+!Z}5~R zB~oACWN7GCNnMy=sfArpXvT&{gMu3{0d!Gr;5w7>R2vqz_tHpEsZN(Y$`(22GrI;Z zubTS}*oP8ZpHrVHWqIlG&vF7v{<>P3(QEVOo8IGfe6Zu_zsn?w0!QHL9gVgk0F#xq zXRh6p*pFmhtLbUuDn>(5W|fxSkw7hyTKl?M#i>F&Ptaj0 zZKjrDsAE(9QyBMZ^lU0W8iR>gNJf)&IU-H3#jw)ZO)n@c7uA9eH-0dU&cXP8uj`qT zmNL&UopbR4S6otR(RnsK>X7P466^k4{adX}Zr2BmC z;zu>E)b>FN`vpjfxLpl7=sL|!1aLj%wYEK8q)RHhb$p$IBsFki`Jm zrQk$TcTCr|p=RDENhku|P(yHEhC&ymDE(B1!e9Pf*XgKF{Uk}QH>Zsy30x39yfjS| zdSIIgVVY!V^pKxZ_K>BG&yV^2{G;yGvFh1rF|1DW!ZeJYMWwN+KK=ZfG$As0;`VSw zx914mC`>bd}XBDV@G9idRLaBEU9<8xnf`u)Oe0pFZw*xWk8&R&Z#7 zVb?lwsp~p&`5(Z>S?d_oNRV*&8L_~9QOi6stY2g!U2r$6DwXjsv8Zh7JlaCmhYA47 z%R$8Shp@Tg1bv+8)T(_8vURv!=D&ALO0m+|{#=Jj3Jx0GCn9Z6>9Dgl*@9gE?aHg_ zh&R$7h>z*a9#tvnKbht@A06FjW8GL1+>Xv*=I7GWI6OdUfl(XH5XaWeXbe%sWyXSZ zZktV+4uY@jxcwQ&c>%gJn$_I1f2VMZM!z_~yP3XaY6#~mO(|VMf-Ki6PJD*#CcbxI zynmJVBA;J8gly4hHM!K_eH*2DM-NradbJ!gpGGH4Z^yRIvFhii&4~@>j4ArV{!2HB zh<6?68B8vV=$9J1!zh@6=d7$>u#WujI72M|F%XXvWWTjY^G z7KaZ=qSBVV_c|(}tbvru@rG@O4{9;BwV$Ge?!y={Md-wB%G%fKEpp_)O5?q*WJ*-K zA9C3^zLGbP|0+5x?!HcUN~1adW5fqzMU4?A$`348tVByci@_lNE!a( z$g>YM;UL2bZS>oRPeuYbvZ>{0+@DZ10aVwbG1}g(^g8~H_T}Zdu|j_mUIqL%o4HC+ z^NmwJCR0edMHw`*BKMK^O0>G4C*#ZqL<6)vL>$0QO%T$~P&$=?xLUn9xur%QPIspe z%{G5pTC@20aCv#hT~rIwm~_sGIx%Mui={y%uq+4zVUROyGc(NksR-6F%=?tRN23>h zW`EZ7wp;5;Sym)m$1?uK%JK?BliC?CO0*UO*N#Ea9HFQ-TOsb5!f#!6uB2%O>~#cf zn~!U_+?K2+Y~3$>VNiV73Vs;bW3O4a6ucXKCv} zt|_@;=&9H|#nUz!%^JJ!;OiYyS0KZ|o5h6V{m%uh(sv7)kK8|}9dW*utG&yl&--at z(Vj?KC6OUT<}c_=V`$SN#fTW+Md{9Bpsulv23UWO zwe$W_ggvu4xS#v+N#f|YZykf1WrAtc|B-anVNJbn9|h@9kdC30w17x6QE35b$svd| zNROPNfaCpJ9X6(${A`B)6sVc;s9+2Utu7+une=i3HpGU9ZNXvVqV;2Lj%#73`B>>kpI2F^HG z2{l$nVp&>XuW6_7_Pxs?Am!nDh=g$0P?93(lu_96-g~qbIMO@HKYsY%v9OPJi$#%y zCDP2gl_s9Bh0+vsB~8Q6sww~OLWft?~ZRUZldL&kTx$0WDlYNh#UZ-mY7>v zao}D>y0nZhcy||D;mQ2?>%1%D#7(p04WoyTA9YS_^z$F`RopV5-X?NS!d4N6ZX9PR zMb@1Zz1qCa;yFLmMoZ$qmV?%)7|nMBwnFGdfsa^WxmTWZb=H7N)PnOaBE| z5h!)zhvQ1lN093gmY@rf>*zl>3LotAN-c!;2Ud#b6(82O!BI|K)AE60-)|YI5@Y3B zB8lVIEmdQ%PUTkb@_$mHhFI~oF*pbyJv!a0J1Z|=5Ljdp<>Xql=rYaHUKZeq zQsA2<0=09eQr$RL1ri7#H*IZfoeE{hZ<~$Fw_Z`VYdF7?Y+qZO?B17xB=z;A^y!cx zp+JU@gIQ_~=HvR3Lc#a~?6_Jz78T`}JxnM1vJU$nNh;3=zj9(eb3)f+N1I5E@T;q> zQamXZ-U%>i)g^)2+HEB!tKsu&tpE4v6Ij#=gMeuw`g`~1r5_K<&%@V4Z_`_GnP)>1 z1k4o+#~)05oU}fg@hGStJPn3891}&cNR?iPEfK?yKCxeSj%_p~UA=W}%==E-#|(8L z_wg`n*c$dNuyfr~x~_paW*cz!qd5~4L3DWA9iq5sR~C0k^i7*7;<0J-XA$x|y^Ek5 zB$1V61Vi9I3UQkId|__bbRw1HmE3i<2;X#xXJr+ z!RPyj%3Q!Fy)vr#kZ?C0;~9NrM0h>*1Mh%c(3;vW4;4lqvb!CsX39Jt8GfV~Ty**3 zhZ}YXWE^*RIRBg+lDn)h!dzBB3&S1~=|5bkfOHl_RH&T1!$zm_ zBqh4WsGHL`sGr>!^4rm(*FwYOclns2O^a~^V2f%nf)ACDA+=)uH0(MgpEbR3mMXlg zBdY((qTRxBR!6+lOECO*_*fq8_4j-{U*UZCj9jx|0j4L>V7~>y@RV@|#~0-Igv9gc z>_CKMDaYGdvqzkRTW)Qvn7apRh7W@7QSB6sv0c+a$D*&qV4YqtH{CjX_SLiS`DHHr zy`mwU?vGF@Jmd%GU%TkwMmh7s?$HM?A2QET?H(nI{_~0@-kvt-3CUP*m-E%cj;$@b zMaPSdV8(6@aNX1%qhJ?A8lZ<#h1YTg5h78^p63cxC(dT&r;XFm%3GpS@s{g#-J2Z(4$b%HduJbOM(a7y*dUkvd8BHCTb|^j03B`*?(B4di@_sZAIw{qtTDrYT>~F-hfJS zDHm?9HNkt7%a3z>it02K2QuHV9p;=BQJ9UVf5=lc1agYr2YDTf(rcf5w|N-Xq_HEI zN7uunH5LXND9g@OxbZ-n4=l3TEcltny8|RSI$ONNtEF4|={zub9W@V>`xonfqaF$43d77BVWrpZna4fc1einB(!Z%TSR zQ;4rHCzJcmOlxT&#`@YY{!Kj5v?hq3iZ{R?9h|xac)XPse!CuE=Z3$yygaXb%9L%U zHKrqcU|Cn2$kJ&i&yGoga06S!o3c&7e&VBlw!hJz-Smg=wb>Y`x&9-;WN!VIJ&it# zZ=iU`yJ0Y-aNe)twy}a275_S4w6~qVcx!QXS)zegYM@)~9X&YJA*MjmQSgf6PqJ|W zJXTraXxT;4dj7owZ)>8~m*I^xe?E4agpR!ah#M3UZm3YDLh;e`T&erc8)UUW?lTz( zvTN&r{@K9bjI-82u~fYAUuwkld4C*7;!!1dnB|mURtVHz{_V>R>H;kMdJ29#8rs7c zUl!a(^L*#!^Pd*a?iVYZx^Iz%gtyNO;}^OVy`l&VzkYu3;CJ&?>D}*Y?K#K`ba4#L z*BSPNHr(SXf6cr7B+qD2D_axydXp<1O86v`OWFSiX>Y z*}gsBqidL?(JYb?V)UF3?32HXLoRBKit3ck`62A~a>qD0I+xJ5A;{v?6vgEYt?z=r@ z4+Wc|OY=XgZ95`D*)hmk&y58epCB|;T!QR*1dCC2hr({+X_ZH54{2yeGQHL|T;DBu zDt;_5fAziUTLqyNgzNMa%eszZ%0Y1_2~2Ueobvv?$BE735V&h~!w>`MceZVFrP$`9{yYJG2kgw9V-+1J)aA&W1IuFQuWmxmG5PS|Hl%zM)fIxsgU~Bd(olI&raByZ~p|7tQ70#C2`b|Q{#PjIvSvupj05FBNLF^gstpH;2$p# zsRVG-qfkW6J}SbW>X}q4mL92+C9Klh+U`#J@3uKmi}D~>M-AwP?>0y8^zm<56kH~M$egDK5*m4)u$75u$lxaf4=VA8_(Tyy zycH2idI7})K^ElIuue-NOKmdV7+n^rIa-do!5GJV1R38C^NN`CxWpIm=?ES!`e&G5V8V@=cBX&PL{-J zkPNXikMhM$*u?UY@)mOVojgG_Zw9uy0$d=iRFwTc5_=Cp#%_wec7d6EYMC6IA9yfd z!Biov;au_Z}8&oCG0GAbAMPt07NKh8=$dHIE)B#cS|=0>K^~Lng!*kF zT5zQ_4krg+&3bJ7zDP4$nBqSXx`drZ1{!s&I-E)FEc{0zfM*{ncX5F=685f>VXKQ? zc)^Cf8E=nrXes&>(9>N9+@zXThWYSJ!9YHJBvlT4AB59)A>W8>Q+{4F&|p*-k!xbl zZ|U*uJ?@I!qIrR~x@HGXhuDYABqt=)9Asuz6gleOc@C^!}O~nMTPIu87KJ>nWZHzc;HNKIBR5 zzqA%S{up7Nh)>IF_(H;mDfwqpz^bC6HUne@!{y;b266@|wJi2$5(FwKx*4=_0e>2gY^ z;=JxGz9JFisla|9b#hmGtz|b8fD*@|4~p%;sw?Ye;Fky!7p@m?yn$r$>zgRxG^~DH z@4(}pWpkbPogLIcmM{4F#3CEuY%Sn4*8~wfsTkSjG~N}k%~GO|K@o;98vc)9JC%?B zh_&jqrEmVgT?e9HR_FK_@#apA=10WguTA7#qRe)Rc#mT8)2B|}?sica-t}PL(@G&u z9g}*chez>gMZZEBS_LQN*LlBzL~883;p0C=cRG!Y=C*pvt6hKI8W&%Wdmzu5 z2XiDR0|48(eFN{kc7k*2-lyrQ8f>b36c=Q@4$A0ji3$$e@8(Oa!Or|s& zh6DH%d}~QlRb&RYPD;`h6TXpYtKRkDenpQMp2Ahw@Af8E4GcZ8|M6QC(3+o?X3i@6 z7HgGrXYc|wE(shr8fK&XXg!F&`s1I&O4*FZcaPPi08!(B3bc$fH((`br0He1iZn$q z#vrE7{9V~DU(zSLEtUUB0AhEpoX7QF-qG6M<7DeN#PMmxU$lZ**^s}nld;_A{VB2d z;p8Vb+LY)|Ee#me>P#I6wzzVD>Ou7Lw79aG7wiova|FaK23pIxU>= zaQA0pA+PJYiqICCtc1Oe6VBij;QS)3#jOe2nk~qBl-vBZV)iHdMAE^FB$4D-c40C!iRsKFLrB4A8rHVNzGvEldM_%Xp zQH4u6uTa-P*vAXBnKaS4fVst_Ct;3-v}VKMh$^(Ze@WvqQF)wMKaadz2SIDUaH7?! zO2L#+N&A@F);A!re&%0Yuy{zG1CH;E8ECDMYoS@@d%&rCs|eTM?Zvfo4m?2kr&e#} zzVixv+sDtMX-;K zUexgIhE02r<{?A&eRw8oqFSfWO{6R^EsF!yoOc63D+sg`P|~r<#jEn+kuP9$Zt2bb zHDw&C84|Cv)HCCqs~6MbToZM?+#7Bp>A%*z*p1n?bA5LHa>K(od{SKQ(AQv3fVm;W z`3$VPia2Ewj*^iclzN2zoUA=nU}XHUIsC+=Tp*-e++*J%a{_G7umC=by-o`k16nh) z_B)|g6Rn(u&T>o$Mtra&^Rk6&`h8i06X!Ik-}^C`=Zv=<`Ad9J=y4LfF(&#mQ50&S z%B>c{5Lc%cJT}#RG*_ek4mG92!XnM}dgRm1k-2f()9Kx`WfkmJ9N!=Mwqo^f2N7$V z^D3T(Ts3?*IESciWO8|b@&3EkUVaVOM8i52!g&akcyS!+?B?~iP~A0f?EF9kW=%Fe z{a7Q*aozhz{(5SR>fn8wds-cDX-iw3II}M$5*er>J)?M(|D?B&7v#q#lnc#B%BY=*b9?5 z&J@aTMmK6$n50D5P%6+hUaMR^#>3BCaVnbySLVbK1gF3FTTc5kv&Wa$Kh8M(6tla+ zEtLOo=h}ZEy&DoCLr$y7e8GKXK?uipUTNd=cGvp%W!K7`dfNqJyXYS5+TKHEt)UCz zlPi!SKXhnr-}V@EPfFGca{=nYMSKeOc+G1sw}3bnEY~Y8)1AkM?JFpTJkBlg`n*Ga z@M=c#zDODS1j*-|jmH?5WlBO|PdiyBTKd%*_P5E-l15QY?8}W7k3pXXw{1UEZ^h@U z3d(U;u{T&~nNH!;GLD=dr`n||kTfLUNlFKT3c30XR(sy-;kGuucbX41(3-9kRG%g| zP-(E-D6qEb^o2TRJ!E3ng{Eo*XB`32$lWXT6kSMG{5bHEa>G^=mqqbQoY16=$x!ug zp^)EyYIjpL2AdO^Z>+npr7hi$Aib?AyGI0Fe?___SfQPw{SkBL7_SetyV22m^SMQF z$|mBTvLpqfL3P@HwFMu3BLDmSGkb670%%|R>SZYPb>Sg_4u?V?_H%|v(%UZ1?eILr zE1^?LGU=D{bWV`A$!nWuNAQcDYTfJx zb)iZeKI?&z1}+q4K8X$1GuA0y7q3B?=PU1#UIuuPm9v#-!c8Rh>B0&f9;)kMf8HG3 zw` z?Xy6SANxKuv`me^jQ-lIhf=IoE|{O){{wceyh#YdXRaMCc+USv(j}u~yV8fC^0ZvO zfPk7)1{a5Uf*h#>>E$aYD*@!cR;~*+ztXjZ3GB?)Q-377OS7OwU;HyHSrhsuSNR{5Lty~v)I(V zi?m9E0fHRvjXj9j^=1^1 znhMx>Vc3nIHmtXnG9FyKO)CcN0q`OV`yye*4{A z=N{~G#&#Y**2(z7?!7tt!1x8BWS6hRu0ILI4^Pjl8=Cq!UAkf{5vJnXFucm+Ggo3j zonNn^p}p>HwK?tZ0LJKAIQm_wk>p?(b-$s=ya~Zqm9WuIY#@QQsd>m4YOYF z_FK98K~f!Lffc33vCEo{#tmGaqxMT48g}n*OKX#Df!EE;p9VYo>ODS{$r>K2ryU!5 z!k^bfaN9!Qg&*;mUZ>Vk6)3o=*p8VF{L{@iQUU59h&SZ9&5DJ-YF|Gs38x%a5)oPn zR0a`I&M|U(UKl!^2REl*S~F{BHXQv{dv z8Y|aeZ`~F@Gb!nPY8k+2Q#stUmhaG7w`mzTqwaa}!;SGp-!vmiLkbtyc_P!c>3s*U zN+B*4sp1oKVlR9&GyuVGgTYukFcOW>RRrd;;2scd*+%bp^uAM{ysFJ0fGn+h9WR{R z6b0A(9(JOGf5l~;J)FrAu6>8nF{O2`ea*&c8VF-h#t{A^QLJ&XyHDzKvL4TbDs&n4 z>{PQ}zfJkGu7!NWtue8HH*nXo`6ziN)2EBLO_6dZtVP>i>cthG!d1@J=D_ZJg_`T8 zQHxIekxU*Cf>LPJQ1fK@$H$yJn8@JF#v^a9r@6V=J~KW9=$ z0SvU!r#FQ~hJ47q-@k+Y5=c{bFIKYpZW|I%CUCVh6kENBlIbgv?&cmiRBL>7oD5YU zX^>6g#>b1uNEfe$Wcf94ekBZlAPAwD)2keewGWicrW+ThJVdDapu3Yz#RTmKVFOpy zJ~uGTntjdtW7%&e^Re`DQk9DEemNi8;^V>YV+8tS*LIdk(zhstZKh}L?-?hz)u_$l z0Uux{Xxw-8*<(TDp~5=q9(sk*BF<-XSD+ znsn(ug$H{!5u>U$1xmij43@t09$z)$KBNq?@AhvQE-bYO8+F_ol;<|MI5cpZHIf6R z3%MM?`l>=vB4bk1ZVNvs9yAFZ3UjMA{yGrH*gceP8}_>v?lV~yP|-Tv)=SnfTOJ{y z=~zgL^=?cynvQ+-gks!qV4VkHS;kBI=?klOA34?dN$Jn09ZfcKyLs0NYx!RstGF`P zH%yL+LFLAOBo^yL%A2EZsuU!$?&hs*74lEC+WmIe^Gyt+v(V~W+tXSM4)J|H!L{$* zC-dbU!`6JUt`V2b#fxqJukrj&LJIs&C!np^PXm*ZycjN8EO`91blk0$g3=fGq0~-G zdIP{CJ;}i^eNnB!hdNq?_eWwmSIDH!fP9z+pWpFsnEAbH{ZH6V4)JESXzGI-H7WfN zVx3`YEiGFyH3b<+|L+MGXsrHI>gki&G%1P0beEX)P4V1*!c}-`w)n7UnyE6+lYWrp zT+`IO0DEeM?uzsD99o-CRA)E7xFSk_(^}3aQtII%Jyb1Gr9`e# zUGt4^yg#+U6nGq4y8om3!};eyA#P|%gXII0KK)r``0*5Te(K|q2fn*k?vBh|HB^!NwTfHJv@B9k!Bf&n$opQ>nC}89s&d=5v;_KgOkMJR zx?w^>q9#f?c0E)b?Z%nyB5u19@4!ZMLj6ZlZacP|2$OL?&};?)0OG#D+U90v#3eJU2(AALLAfk z4ySiqxC>(4wl33CkqIXeQ^D)C%E;QPUtKZ24Y4oRgP?h~13wuvl97#fl#=XQ0J1*q<-ZIrq|?Y2mHz8pHb4r7jZ?|{_XeL zf}{b88{{pu!te=Op*rYLub*A}VlMx&y@vQ-0DIUG?HLi+w=4X6f-M$q{ELNsOzC$M zsO7v_-`V}xd)f!ggP3j4#=rYBUi2+TrxEiYIE-S=Z37-nF*swgkYCSs6ZDIx)gJ>7 zzBEAFmX)%v5X&N+qYge2!T?RtMb?B%UaErV98a#WdCirC&&`|cje8-mywh!1wM2dW zY%`#W5m0I$54nq54|wC~u)zMe)jWM$B8B}<=7I{FZ!FgQ<>dMz z^_EzC<$DH_eLB3;u8oDGt+n-ws`3)X0M!pNarGqke*xKM6#Al)hEw+%HB22+rl{%V z_W?iGa>vQTmu&IM9`ZIlTX{$WBC=fCLr~jrPCI@e*n`Ze2uKv&h6L&BD1lT&hmBLy z)sQdP-i}R0aB{)&+-!U9K1tUZG6`KgGJtg#+SdFz6wHXA{)N`%0W);fl1F-RX)?-t z_7@YM5!UK+p>=N`lK2OI*SyMz?-{F?qGP%9 zqGaLOZ+Oq=75;C*zs-1^E&`bU0Tnb$emaD+_Elh{4@#ai88+dl0waYe&*Q5yE*M)L z(jZd3$HCWLuBVW{<2G&6NEZXT`K9l?*EO$&8vbr60P1k3$aljf!?lq02h;KE$xeS) zm{jlJSutMSqk(8}FZV5fJqp{Z$IZEMExc*c&)yBG8L9cb}K^4Nm$MxP%fx>)*&Pri)~w;*S6HM7s~2zC6I^OtIrzAWh=1UdBq*(VJl zK3au!u0+a843eqU7R_^`ubUzBwX0dqOxuvFw!~nX86+yRzecab@Cf6~3i8G3v+|mPN-UlgZvWJ5N)GywYOJsxu_C?d(ZMZx{_N}l zP!S>U<}aFYk%#-RE4_+Bw_m$Pl92^obk^#iqygs{GkS7#q?thMQ8<=)3)v93S9HZ1 z>VVa#0>Z}!EYD(>A2hp#1zf!d=_I@TN}cP=s83py^uTg~^B~>I#|=9?JW@K8tEQ~v z(Z7tqJcm^Q5Otis8Arzo%IyKIuvU z@30!3h;-JI7(YwG7Zwx16SAUn`7@y}v+lIna7vMSzKx9Ci?R96_=3vxSXvRchhRfG z!6P6wkf>IJQi3lal%Vl0+MZ1eWiKSiqCEUV5KLoxH|0%~-1kxI0W!?)+$B`(U@f4E zGz70AU<{$oj|}FS&hBC1pJYEct>XQVEKtuWlZbBp+9s|1+=D_NLn^=2K*tXkL93|# zW`jLlAmcq{vH%Z*fj$)CnAd$o{taR)Agk{HT)JVli9k_oIEL zo;$;oO*OnUap^a1UWDs5J+Tr1C;+mTrRWv>eNZ<%9>H7-oNz|FdGuI@=zOegwg+=m zoO1+mcHQ!6C5=Y^kaHqRsU}&IeI2y`sxjCz#EKHBj5i3!$P;43<; z+%k?iCMU?0nPPAHNBQC4zYEVp<}kBO^8YD!Jg?&qeri)ha6IIa`Qff{Xqd{!th<)} zzwgQqeA9>uWxhhZlR0)eMUUoZLH@3BjKiG)u5XMo!e)JxeIXH)B>=4N^grGV>%Lkk zzS@c*LQz>!F!EQNjf^En>oEPd&UJIjWn<3{6Zgc2t3cd*yAW{&Lh~Pq6O6$Mntg^y zMs+_npO3KEMsfHqLN2OrzuGV&F*Ie)s9Gg)Tdxk6!CJD9pFQR%W4Rx2Zkc=N24qzE z9c%WOt#X~;8$<8Cn@AMQ))%3c(5}s!XfO6`=IQ*7cO8C;CFY`ATsE8Fdeu?$qH+Dl z0^S^Cv?D=pdEex12{w7LICH2*uW3x340(Pa#h zU2$n(2`((-Nd2@c+t#BiCjEV=`8K6p$og%!DExQ1RE!I2P$w2cXp=hD9=|xLT={L? zKB1w))UX&jF8*rBr?@q*y*>K48-t}tJpB#+C+YhiNr7doQ(^{)b61EA)-#KVhfSH~ z*Pf56I_+-lIpO%a0KJmk3Ylx3c*J(Wv}~@aIB}4~hjLmXX6%(iTRiuHpV^ zRYSwy0%h3B5cZHF2!=fvk6e8)h^mXvflZyxctD-bd}dnzr`;C^_om$(@6GZL|SQY(}AFf-E7A*N4E;y{#KIxWiZ|nw1UQ%;~`?FqpFwBpO0U6sGX!I&LL1jSp{5COyXJ%#7v86~`ihmA^^!KD zrcA*|cHN;Hz&V7RK&ziAPJV>ltktbcz2YwB>svtYaGkPQ&UF{#J1HCGC*vyvRns!}yGaDo7)8>O*Kag#Gy znk!{S@t8PRLKMYwmO@oR4P_5n-2XuR%VwNRJQ_APo@s6%&ZYZRB3}#Jq_h1fgBPq_ zx2RHmy))=hQ9scqX8a#XqLJ>EFF^4Sb3g5DTt!p~fi4)^L+}vnLU;QkEI20gSvmBl zionUMveB~h8zW?+e4<`J_w$dWPvLc*OpF(i6R*GU92-y+wAS>fbB#`VZS8VxbLoJY zUwd9w3lk&f!`h^L&U_-DQ!3Okb7qS3(4X)S(JLHf7QdJ|34?b_Ll{G~*5YYKq{DfN zjA^&bZhyrMF(wz9`DMSmtZnjoprU{%T|AEgcxsS{H}&GOlit8+M*BaKev#b6{eG-T zC!%II8L#>Tk`CZNV1X-DJPph*+M^f#HSOfVT4Q| z=L82Q#$9(YpyQ`e5ZHQ|e3O5UW51vil=^5laIY}@C% zi0D<-hDtY8P6#RVPKt7lw<=EIoQrrqZT$Cw9VQnJk6zU0^{HC7CcUIs z(}es)aND~JcS0*kQC{ZOsn-qdpjBXW0Buqg|NAeT6HB0@#D!ko>Zv@7j>!ydkx37H zN+q%w+SXalZ3{|u9{MS06#K7u4k>!<0(%w>%ENDb?l38MKk)8ERNGQkD8ZDEoH!NH z-8;0fkQsUgd+ZFWuIvqCxs8QZs~~PsuOivvf5v*OcranbvvVGcnR+o}v# zd^<@I_um|4$xpZsd>Kdz(*F26m5gzj0)-DUxE_Px0R;DmqJ>z3A9jP4O~bHDAfl_{ z0)%qsj*ex>uZ)M?=tKhmcxnUI*JJFf*2{3gEcv-)-Y3#rVG}LXRR5^>DuvpDyswi8 za%*9_Mu2Z5Cx#^syYdXDnOMgy*Ey@HowngQFPf&&^A6|b!E^_dVQ@=p;X~Gbs3(ai+rOp+MRYCSMVMm`6CbtloDrc#$p?=zX-K_9nk!T8ORLU84w8R-s40m8M{SsD zRB#x^P1G6f2Q4JoE*FtIb+;?oVVOnkWC$w36XppDtETd!mDQV7jNa|24VXAe2slOm zT^AEM06&$1nL~e3J+01PP{kc)Da(pgVA!Ic-FBB~dP&N#Ux z>v~%gxd9Q~F$3QbLpqOt&f+O7g9fVO-B!nM4}ROWOe~^3+G(CGK6~bq_U?1`F#rAZ zWDRu_qefcozKR-jn^kwISXOiu_hU;;ox}H2 zlk;T!n__)`zR^dB`&L83*N?#Nm63aLxWeIzrLH!fq#oe9 z_KU{iGqG0PL`i(4_iOAS#1M%ul#i?Gvu~Eu3Pam5Qbp8o*;H?{)AStm9?}agcd?2%qv{5RR zo4=QH-mJ?q%kC7BVa^bd7%B%Z)hMJ@_0Z6vXTCd2{xbY1$C2LpY5JVEtIquLd2JYw z;qJBi%biBtmmfpksQ@}ZD*6cAy;u{COZk?MBE&70IvjV#&UJo6k%+1sG9lvpEm8;tmaLJNfTqIj& zxGXgi1$&fsXRp)feNNt!?N`kTO#hK&48m4{_H!^VVhzRy3|y1n!iikw|B*=Njjhy= zcWU)bxs-EnRi63y!mnJmO&oz^m#Q>_I9zdOF1+U>Fil749EHFmSjKS>13dU=sM?29 z-4vg6r2B+KdDZ-HmxNbcYZcCfpVE9JwhCSVaMg)DDm6dxTui1Zu))w*Q5f?Jknke*TD@M z;d|=qQio=|7XuZPVkE0?NKOAEarPp|!_aau;(l-{eEouvg&$t9u^#qtltx)t@aTo} zch2yaw7*F=JIhm*W6v%Z0pr_A*?vmJ{qd zV&Z<6b9tqm**!%@Wp;==K^7=bC%Dk=JEGshU2*Cllq@%LUjZT#2#a}cE zGbG-Xk3X}`BUcD}yZyS#JU3hC<%a&SThp91t_F6S@C03AQq^=wUjj97(PtJh>Q#PT z+~;HHuk!U_iBY^)KuXJ;O>F|V)!7u@?PqDLVt9E%<5RIOcIV1q%i&3+m)DI^zs}dQ zEGmuemO{#>3c(w-ULP(KDnI+_S?x8L>;2N&Qg0xRI-H$Z7as+wzBVKNlYjje^=z$@ z23nww^+)u0F+jVY&77>@bx>+3v4W|JdrecG8LUU=S_^Nx;@NL|ES*22cbh7TD$x0A zCY}GONR+=^b^fY)0`hFU589jgR(>9bRxWw9sLBJcD>W-CS;ephg&#P zL;>spknL!4EO#|GF5rJtijaq*7?nr0>+K_?DE=eZxNQD?C$_5}@6>i}t zHaA>15NY2(vvlJ<&an`rAU-b_g#r5ois8)0oxD%V{u5A=Lby7MlHXA+B*u$cj@^1b z|AaR)qlQ9c8YWejOdj!4VAr`WJ8nS6rr-355Z9xETSGYZ>l zM{&nD`;ph5~WX13Fxi^?pi=Y9+Y7q4N(-7~Fc$R=XOGysrKA_PUlx_p!IW_J*>3FR+P> zpoY*EFD3qvFMm5d2TItlQA^RhTio*B&67n_-NT*<0jH$5{!HC$2166^)5jBr7wN!a zksla7Q*ma5LVkbipT0;vDxcQ!ZSws9%o}%pTabP!(WQ@;ZFL*#kry{Sbs)}zD3+ee z)>lo@^7LvEz(FTJ<@OV>{yrQlR(yyHZ-Wk4*;>1WsnOg*Bq`s`q^2#XP`zS=E3xXO z(a*9lu-twe){m>aqFMvG^d@r)ez9$O0CqR{Sm+~Gm)za(Hw8eekEaNm?T-Gms)lfb ztb&KPj3cy^<1c44nGNrWXxD60A>CF!2|m@UufAJJxDbSqZkjL>XOR@^oDTJ~gY!XG z9JK!bk^D?OuAF00S9;%hm?s=HeybkwU5L2|Rl>B#_vd_fe7<1yIKsYm1`ut(i9TH! zwTBQVL6mPCcxz`e?u_jUVBnYAwsL1`d?Y5f)owMf!w|bK_3WrwC_Fh+wylguN}o9h zzn;$Aqq@=vz2Bg0waf1h`sd65x!l&J*q6m7PXrX!|GuHXJ15+EdIe0!n9(RRRR=qt zkO;-i(%~Yy#(Pcc$sxWrG*TtAWuxfr{_lqcXGM6m+JpL;`jq2azh6{cMvtnl0L~iE zsKx9T%#_S2!jf|WciY?j-0pQ(Xm7<6?UR*c1ElMx5(~JlouTvb4f<5h@PbOP8!Xbt zDm8J~%xL&{trzw}yEjz)e^MWOC-%?~mzcY@EALO~(E2ByNmsU{_-(9N=haq1H3Mqb zwx@pRQ;FtzdPG=IWg={~R@*n9IJCP~u0UwX@PG4LT0?cm8ryTNoUamR#*Fjz<4~W7 zNH7_{hb`)pB-brd{&TU<9-0=TIijKXx{1Ai!}b(Ah3StdBP4lq|Kr=}S}5KH4E=W@ zSu53tF{ipMziO2}|482(=o9@IDSk&Za$tQG&~Y=L9)9>KdGqEfCTEMh738Ks&;AVSiBUmtKJ2vod$@ZT)olePq3Rzq0QA zhwIFClnSbS2=VIVcXHn24RvA4QZv?OMPL7Epj3{wfG2BKTE9GHHYV@zHe(w|`~q8k z_rEeF$datBm~@g-&&FJY5iUbsF>hBOh0kR3PJ~=BRgSy(*@?cH7vq1`+scB8 zeNa5zTg#S1s78bOMg9aDpoq5({%1r^I`yv#te=V&S6=)CdX0 zyPg!lu8Wv@^vKfs9<2s5ZQDa=zs@-VFL0J}GkiImtT8_Hdz~`edOI+5Rw2>Ftm0;w zV}mH+bI~Vwon9c=5bO;63WtDzbKeTl=dV%ABkv0MpV;38g}|PA}lOyWhGjwbP%4~9crGj zE_KFaJNvQ4FDL6$wzlBjZ1E4G<|e9Rpi}=T zP%EX!zj15*%Xsui?C0UlL?A8fm#PSH2vKwXzLgns7@KzkXU$Od7tdrksh*HJa#P9w zVNLU0gADMpJrzbM-W*As0AsZeBN4SiaDo+%=Xk9$wl;2sb@c7|7Ypm>v$!A{M((QW z+r4RbJKgsbL=#hw8up8?=<%#)v~k_R=bfWYeU?8<0jt8>wyzA{Ve95}L$U*6je709 z9-~<_q0291ZYAm7>4gd5DBgz}rrMV}U>h$)r~?yJIM#t&7QdFmrpq=ilmv#0|EQV$ z)3#}k#!Hh6&QC$9i$OW{dToILNaOB0B-1W}fe**p1ByoNIuviur{%jlCEoRWD(3fD zpE+vsaSN^8X}Q5?qRIq82?q#ev2jna%hAec)rAs#VE3O_i=+P6Wq(GhTfK|DBMK(# z!fz8)rrPDgd-7Cg9(e#TeJ=+cu{_tLSwogQXHH9i6#Ur!6D_W);oiDAcSR@|XIqSO z=(@~J&2iyU6`3*KS6E`yWD@A&QhV-e7Poh9@_ zB$@rgOKr&r4ed%+V9B(4UkQYU;po~^abUeh*e%zF`Un16LPxw)l|N#i{{Hr|TATL# zfu%|}EW-jq%6tf>2?w0YI_<)Rp@C0u@Tae&mvex*QAnrywX7P2p-J6fK9nF35p7r;(lcp+r>A!F#RNk9@Syx|^GqqX&2H4^b{dUabj@o4?pd!kMqOqHQTQS`}=6zRu8|DK}3u z$ZQxEx6+o)pIty?IT6Cdcm&IQN)o2Cb6oXZWsLsiD#(uBJvg11@9HEA!2Y87_bid2#*iKHlk_1i03-l}>!``Gcd@-^YpiPz8;Bz?wr}6a$~n|Bqx3 zfqj`P6^H>*2!)4)6R)y}$-@rB$@~^jqCKc3AY1}JkxFnVm3QAV&e;N;O#J#4swa%- z3I22eQaLH_Z^;rZ?lw}$4@cm2(ES9aQgm!F3H-Jj2^Zqxy1BR(3|Ab3V_mnYk;c%vc@Y)W^)TCf|M8#=fc{^lk@Qc@vvXV8)y|UT3$1;Bz#X^u5DMe6}GTwiaeT zv~#ZC{K;}qafej_8j~!?8%4TFa%B<&lnnrRAp&{|`&-(}zIW>cLP#-YH-4v`QhuWP z0bhQlwhLoybtu3Ob-&;+nO78DP^TMA4WHS}ENon-m1!b_pJ`g|FrBp^4Pq4GaHbBH ze1KJ_&vJHoul0JVbq%vG?465{p&CiP@!@=6ymD9ycz3;!tGRXwD^;4YjVJAKKtJya zppLhI#S`IAueoQs)}aY&cnUk)tno~n<=6_&yTI09#u~7@PjP~Y&2F~!NRQ7IZ<~1M z79*3?2iTMljB?g*FnOYyJrHh*P}9{zYP{qtY^%7}w-I=F_#d?yGCTg)*c*CPo-ZP_ z7Y!^g2EnV1D_H#*B{*;~{1O<@OWR5ge7{+~dDGT79&S905M)0{6PJV1lOK-yGadx$fq06Q? z%|9LgTE%8$UO`?f%k*gE%-c?M8TUL(HObvv%TM+kex7_8XM+Ea@{;+B&1xz1%aVE83Ls&cfnr4%fjyk@CvEgEC)4`0K8qmr6PRN7s2+GteASQZvZe7>j3NY27GvO)^oyp zZN>ZBPQZi|OEF-uc0Kw3*ee+j(3K|*A?jFy@t*MiNMw0{k<4hIMlC+Igzy|yTW-jY zY|Yz2UD4QH@y}Lk3;SMm(|EB71k)qa0XLNRGw@w8Wg9f?^9sZN^>pR&P<`)zin3;u zWl|VLiDWCulKE&8#e`6VvXwpgSVG2O^= z>X0-eySH&544Wt=lkaxC?@T8!uNF%)O0!dm8^g6dEo~LW?-ZDt(lIt4m#bnO3UvdZ->)K|*#i<`Yi(+iUZik!4;OytyIT9Az*K^i?UHBL794R^4- zzgJ&8v!J$A+^z^{eGo`WPqSdmdoxx9KBoN*+h@@ETDs;}-1B04s4vkl_ny4ln=V*( z;<($I^63QHJ5bPq#oYc82YM_CR;wn~X{Nr`wR;SrpX3$St~Tfl&JGa91Df~CvLhN! zWfW7MS@jPNswekq`w6U9{nGf)_CtT?_z$e#P!B9!0N+TDYdEunpzB4P(mq13nfVb> zXvqvDX~s5^;Z%(1*V~;BQvHsTsrP*=Fq%l^=<=s_Fwztyrs@wY1|XHzob}1}LrRez ziVf~_RuaK$i=(%*j#XbX<4(zoaT}Ms#Vwu@%@O80SAb&5tx_XACEMB0Hp&K{X3cF) znY9Ba!3DR+xE-|}3;QhY4JGKHHp=tuYyF^*@1QgSI?hNT7fYfE8>RrUS|0)T%~`Y+ zWqHx}rqS_q9Ss(D|1Tt0&~;{QL*&^60K>}&2!$mQ$%@^maC-lv|B4wj(_ai8bwUPd*u^`g*XiHj5<2+Yiy_b z%(>RY9=omR;lt-GQ@EF}3;}caHM460)>L8zPoEdl@-O61&$?-o!Qh+NZVwZQnZ-4)kThLFXXz;ygMkX zz5CU~3u7Y_|AH=~(N|eT-x_<1br3+wsEQ!zf011BbOTPchU<^a4-YFAN22*0lnfL;N_uYTRWayo;=~0+w zhaEXE3X<-KR{=#oVz&YKEKRTVIGA}~^Hx6Tyb#W}q2x=+Z;=!_BDVHAK9fIKH?qW1 zejuO!5l*ZMryAbcN!p%FV^!Pw1SvCnfI`ab?@-ox_8UJ>+=A-jKFOS7m)XUiPpiw} z3ZIR4zm<+?o{zh3vEPwCtgZ6l!IT&iMdmpZXsB*$=s~ZNxVCI!{WUw{rb1q*-DL6v zkYK+mjZH$4*KA97$s+gQRea{4SLMpzG~S*J*GwhlXkZd-a4dt=%2Z2kw% zWZf^M@n_8ozVVM|!h19idLVFS2F3D^UVM}H_3UD;jPmzJHTbm6$lqvkI2R!DzUWX= zyL>`8#<>{?I=e4_#s{JbSr&BDOW4*a>3&FpDY%fx#LsH@Y*L2PXWT8E`EQh@;4Kku z>$P*z@5oz(HaMq$67iF5bdVX%chZDVH858(Wa92PGPT+ra#=i-6C=%kKe07L%L(uD z2I9*HuZHX(hX4n3m6#GNBrD+tiiH4OSB9GMK+dx6HD$j8R940$Rx5IXF!X=X@2ewB zPe5(Tfp1dsClqVwiGAOhEmC>s4Px*P{#eI!qpmSUOuv~f}7`mm$bkBrvl+ z2?^W6%!>2R@sT))bI=OnivU(HX-G3p(w)%gjWK_~8q%v`$kSV#sn9+QD+B}b6{i2% z4wUxqfYVP2)NFOh2V0g#V`NJ;SUPFJtMeBq<~2<>ls|r&^gVc$3pF)c_nR(I=SADT z!F+u7xl{b2q_*vW!&P{v2c@||@mV*%HGiST<@#SD(5yQoz+Emxeb@!{eX7q1qBpE#CY|MaF-#;p}DdZ2RiaLdT7 zrFOFPiF(IHMq&k6-6{S|4RR&t`FJ9%r$qV1-!^HBAB_&S7cqi!TUnN2P8`s4X?qFr zg@R|F*R6!wbWV{llr1ErpXsxd zegCjVV2~VpsM&EpR7ThP=?okl+Hw*g+Kv-yDH6)qel|9f;hn5 zgp;zC3ofF*dN~)~)8l!T()mwV@q5MAYFlI~{e{D>f`n}mW;=EZ6An1D9QyBIhcXw{ zVoux-?2vj2oY=TG-$lOtvBECB6jOekWd5j+p?3uJF6g02wYM=dp8&^5?Ld}UysXH` zSMP%eKh=b%X?NnlTMg&gaWzhi2y2H0k+SQ~`fBCG@3Lv^RNJ4>o!Q-&rC0Qj1L6bt z#K{mxQE$&;9x2&V{=kkq&;2*LD)k(&E8(npKES7Ce_F}4ow`(|)h{KbB@Vawb{Tz$ zGJHuI`?GVpWLflvmG2>srh=4OCy4t!Pc9BJRSNL|$cr6eI>91^vZ0d_y z6q#X5Lx)|#*v0L)Kd&#0Un(xxS0+xw>mTm0tKAQ=tr2Fnt6Us&~ zG1!o*QS|f)jz#|-uOa{$rNBL(oNECABFzO6Op#lr#!4*F0*Gg$q#inmB}%9Osl9wa zEN#!Q3*8JezbyHmte4t46nk^mmG7MP-;Mi5529uOk)ol?jes?R(OIJaCOglJgVAXk zx_XZV@K#YRjvy>6$wf4S9mW*X(cyDRnBkKxAbfK%n`L!xqnVzAP*d z3k#~gaGw*$pbY3}<_X(YZoFb}2c!rZFUDwEUViNI#sBt^Gzqj37pKqysc}D;<8n-C zpq%lx&|(M%zGd_HH`vQxxN|b8*&1^XwV&yEqc4)G=R3RBb?d^{Zz>mV^TteiG?UQh zG}?2OJoo8|*wd$H(4h_P*{(orA-3hb>OrTqwn`Vi&SsekjNeMMpSUUs}`Lrg;j!Dd%Lq@zDdi5?4?i#YrG{goQvo zBOO^7su~&ArLq4SCVJb6QDP0A=n}z6>vaou%(Yb}bSiKfx}HbT|60U2-#R1G0@k8M zLxVr7Dv0t3JpO!%J0**p24G=%6?Hb&D^*6YsYn0anD_N#4B*c&!}6x zq#}x3vjm}cv&Za-s`8`W>EzuNfG~%RV46|$(fJ_ubdwTMgQl^Vr5_)A{(3>k_oZzi z*GmF0bx<9Z2blKFejoy~;*eGSFl2S+1~P;ouP)A^IwDU%;u8dZp?uD-0jIvenYs7**i| zk5+PU=3Arj*MNmC?C0b9FB#H76lQS-d|)x}D|m0W2=;Zto0zgQw|l$si*A^o*6yG$ z_t1Ak>M(vy+ORqQ7CX_Cc>+y!(m+XWV7J#%pm>!BSY=prFp-{tTHccrXFT5vP-=;u zOPDFJEw|A{mK{riU-^U`h49Qg8Sv8&!5!;54(zDHvy8Sj5@w^gA@BP+#d1vKrUYEb zA{*rt$r66aIHC=oKwLey6s3q_ezT|p*Cj*Eb0>YaFFlClLaJ3YGk|!38FLcJ5|L*7 zgOrAdqL=J5?tuU9^f6z8c)hlf2G`oQ1%B2?QG$MpLlzlkg*l2IUrbo2j<_N_i%U;@ zfP$PMJ0ob{w!y`1bEy3f>_8CD>=V;FXe!&YHGuVKV%*a5HE+kwqf$PvK0f0=Y_IaL zsH2$c7Bk6_j%lx!RBqVOx+FK^pU?4lR=%;8zLBL`y{?(YB8G9rdPndVU#L(P;|+G% z?4m%trrYvbD&Y*tzNYgWy?(5=?>B9kCswKT@wN;kZQVPCSYy9{nOnnOM39xZoer)R z-F!9L*yHHx?d%t9RWO;OGYOx478vX~Z|uhW^0tl};uJ*M7W@k?V?da?$b8+xtr3p+V0>RTNG2lMOoc z-wXnL?VP#N{$XiGDkKSR=s6DPsz{E*hUCpZFqQ}2n#rb5lX2~+$#l3VON<_YZ^TWo z4mDP%EcTrtG7UO?AuMp*rB%WNK!8C)gark#LiE(70on2E0@>?j@a*T#!i4-ur zX=ZJ0zrzU*W0h9j1#<4k2!pu^A8t7uwzlZ0^AQshlMfgfbKA?ubNgoW8-P;Jj^?`j zde_xY4hR<|WWK(e_Bw%?esUrRzRkdWo9Wc%1W|o?q>0;-;kX bkN7d}uN)x{uXfDz+XNC#>Imftf5!g@T}XnT literal 0 HcmV?d00001 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 102f48fc..863988cc 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -43,6 +43,7 @@ from .vqa import * from .fce_aug import * from .fce_targets import FCENetTargets +from .ct_process import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/ct_process.py b/ppocr/data/imaug/ct_process.py new file mode 100644 index 00000000..59715090 --- /dev/null +++ b/ppocr/data/imaug/ct_process.py @@ -0,0 +1,355 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import random +import pyclipper +import paddle + +import numpy as np +import Polygon as plg +import scipy.io as scio + +from PIL import Image +import paddle.vision.transforms as transforms + + +class RandomScale(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def scale_aligned(self, img, scale): + oh, ow = img.shape[0:2] + h = int(oh * scale + 0.5) + w = int(ow * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + factor_h = h / oh + factor_w = w / ow + return img, factor_h, factor_w + + def __call__(self, data): + img = data['image'] + + h, w = img.shape[0:2] + random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) + scale = (np.random.choice(random_scale) * self.short_size) / min(h, w) + img, factor_h, factor_w = self.scale_aligned(img, scale) + + data['scale_factor'] = (factor_w, factor_h) + data['image'] = img + return data + + +class MakeShrink(): + def __init__(self, kernel_scale=0.7, **kwargs): + self.kernel_scale = kernel_scale + + def dist(self, a, b): + return np.linalg.norm((a - b), ord=2, axis=0) + + def perimeter(self, bbox): + peri = 0.0 + for i in range(bbox.shape[0]): + peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) + return peri + + def shrink(self, bboxes, rate, max_shr=20): + rate = rate * rate + shrinked_bboxes = [] + for bbox in bboxes: + area = plg.Polygon(bbox).area() + peri = self.perimeter(bbox) + + try: + pco = pyclipper.PyclipperOffset() + pco.AddPath(bbox, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + offset = min( + int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) + + shrinked_bbox = pco.Execute(-offset) + if len(shrinked_bbox) == 0: + shrinked_bboxes.append(bbox) + continue + + shrinked_bbox = np.array(shrinked_bbox[0]) + if shrinked_bbox.shape[0] <= 2: + shrinked_bboxes.append(bbox) + continue + + shrinked_bboxes.append(shrinked_bbox) + except Exception as e: + shrinked_bboxes.append(bbox) + + return shrinked_bboxes + + def __call__(self, data): + img = data['image'] + bboxes = data['polys'] + words = data['texts'] + scale_factor = data['scale_factor'] + + gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w + training_mask = np.ones(img.shape[0:2], dtype='uint8') + training_mask_distance = np.ones(img.shape[0:2], dtype='uint8') + + for i in range(len(bboxes)): + bboxes[i] = np.reshape(bboxes[i] * ( + [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)), + (bboxes[i].shape[0] // 2, 2)).astype('int32') + + for i in range(len(bboxes)): + #different value for different bbox + cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) + + # set training mask to 0 + cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) + + # for not accurate annotation, use training_mask_distance + if words[i] == '###' or words[i] == '???': + cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1) + + # make shrink + gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8') + kernel_bboxes = self.shrink(bboxes, self.kernel_scale) + for i in range(len(bboxes)): + cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1, + -1) + + # for training mask, kernel and background= 1, box region=0 + if words[i] != '###' and words[i] != '???': + cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1) + + gt_kernel = gt_kernel_instance.copy() + # for gt_kernel, kernel = 1 + gt_kernel[gt_kernel > 0] = 1 + + # shrink 2 times + tmp1 = gt_kernel_instance.copy() + erode_kernel = np.ones((3, 3), np.uint8) + tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1) + tmp2 = tmp1.copy() + tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1) + + # compute text region + gt_kernel_inner = tmp1 - tmp2 + + # gt_instance: text instance, bg=0, diff word use diff value + # training_mask: text instance mask, word=0,kernel and bg=1 + # gt_kernel_instance: text kernel instance, bg=0, diff word use diff value + # gt_kernel: text_kernel, bg=0,diff word use same value + # gt_kernel_inner: text kernel reference + # training_mask_distance: word without anno = 0, else 1 + + data['image'] = [ + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, + gt_kernel_inner, training_mask_distance + ] + return data + + +class GroupRandomHorizontalFlip(): + def __init__(self, p=0.5, **kwargs): + self.p = p + + def __call__(self, data): + imgs = data['image'] + + if random.random() < self.p: + for i in range(len(imgs)): + imgs[i] = np.flip(imgs[i], axis=1).copy() + data['image'] = imgs + return data + + +class GroupRandomRotate(): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + imgs = data['image'] + + max_angle = 10 + angle = random.random() * 2 * max_angle - max_angle + for i in range(len(imgs)): + img = imgs[i] + w, h = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) + img_rotation = cv2.warpAffine( + img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) + imgs[i] = img_rotation + + data['image'] = imgs + return data + + +class GroupRandomCropPadding(): + def __init__(self, target_size=(640, 640), **kwargs): + self.target_size = target_size + + def __call__(self, data): + imgs = data['image'] + + h, w = imgs[0].shape[0:2] + t_w, t_h = self.target_size + p_w, p_h = self.target_size + if w == t_w and h == t_h: + return data + + t_h = t_h if t_h < h else h + t_w = t_w if t_w < w else w + + if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: + # make sure to crop the text region + tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + tl[tl < 0] = 0 + br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + br[br < 0] = 0 + br[0] = min(br[0], h - t_h) + br[1] = min(br[1], w - t_w) + + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + i = random.randint(0, h - t_h) if h - t_h > 0 else 0 + j = random.randint(0, w - t_w) if w - t_w > 0 else 0 + + n_imgs = [] + for idx in range(len(imgs)): + if len(imgs[idx].shape) == 3: + s3_length = int(imgs[idx].shape[-1]) + img = imgs[idx][i:i + t_h, j:j + t_w, :] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=tuple(0 for i in range(s3_length))) + else: + img = imgs[idx][i:i + t_h, j:j + t_w] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=(0, )) + n_imgs.append(img_p) + + data['image'] = n_imgs + return data + + +class MakeCentripetalShift(): + def __init__(self, **kwargs): + pass + + def jaccard(self, As, Bs): + A = As.shape[0] # small + B = Bs.shape[0] # large + + dis = np.sqrt( + np.sum((As[:, np.newaxis, :].repeat( + B, axis=1) - Bs[np.newaxis, :, :].repeat( + A, axis=0))**2, + axis=-1)) + + ind = np.argmin(dis, axis=-1) + + return ind + + def __call__(self, data): + imgs = data['image'] + + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \ + imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6] + + max_instance = np.max(gt_instance) # num bbox + + # make centripetal shift + gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32) + for i in range(1, max_instance + 1): + # kernel_reference + ind = (gt_kernel_inner == i) + + if np.sum(ind) == 0: + training_mask[gt_instance == i] = 0 + training_mask_distance[gt_instance == i] = 0 + continue + + kpoints = np.array(np.where(ind)).transpose( + (1, 0))[:, ::-1].astype('float32') + + ind = (gt_instance == i) * (gt_kernel_instance == 0) + if np.sum(ind) == 0: + continue + pixels = np.where(ind) + + points = np.array(pixels).transpose( + (1, 0))[:, ::-1].astype('float32') + + bbox_ind = self.jaccard(points, kpoints) + + offset_gt = kpoints[bbox_ind] - points + + gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1 + + img = Image.fromarray(img) + img = img.convert('RGB') + + data["image"] = img + data["gt_kernel"] = gt_kernel.astype("int64") + data["training_mask"] = training_mask.astype("int64") + data["gt_instance"] = gt_instance.astype("int64") + data["gt_kernel_instance"] = gt_kernel_instance.astype("int64") + data["training_mask_distance"] = training_mask_distance.astype("int64") + data["gt_distance"] = gt_distance.astype("float32") + + return data + + +class ScaleAlignedShort(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def __call__(self, data): + img = data['image'] + + org_img_shape = img.shape + + h, w = img.shape[0:2] + scale = self.short_size * 1.0 / min(h, w) + h = int(h * scale + 0.5) + w = int(w * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + + new_img_shape = img.shape + img_shape = np.array(org_img_shape + new_img_shape) + + data['shape'] = img_shape + data['image'] = img + + return data \ No newline at end of file diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 59cb9b8a..dbfb9317 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode): data['label_res'] = np.array(label_res) data['label_sub'] = np.array(label_sub) return data + + +class CTLabelEncode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data['label'] + + label = json.loads(label) + nBox = len(label) + boxes, txts = [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + box = np.array(box) + + boxes.append(box) + txt = label[bno]['transcription'] + txts.append(txt) + + if len(boxes) == 0: + return None + + data['polys'] = boxes + data['texts'] = txts + return data \ No newline at end of file diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 1a117789..02525b3d 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss from .det_sast_loss import SASTLoss from .det_pse_loss import PSELoss from .det_fce_loss import FCELoss +from .det_ct_loss import CTLoss # rec loss from .rec_ctc_loss import CTCLoss @@ -68,7 +69,7 @@ def build_loss(config): 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', - 'SLALoss' + 'SLALoss', 'CTLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/det_ct_loss.py b/ppocr/losses/det_ct_loss.py new file mode 100755 index 00000000..f48c95be --- /dev/null +++ b/ppocr/losses/det_ct_loss.py @@ -0,0 +1,276 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/shengtao96/CentripetalText/tree/main/models/loss +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +import numpy as np + + +def ohem_single(score, gt_text, training_mask): + # online hard example mining + + pos_num = int(paddle.sum(gt_text > 0.5)) - int( + paddle.sum((gt_text > 0.5) & (training_mask <= 0.5))) + + if pos_num == 0: + # selected_mask = gt_text.copy() * 0 # may be not good + selected_mask = training_mask + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5))) + neg_num = int(min(pos_num * 3, neg_num)) + + if neg_num == 0: + selected_mask = training_mask + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + # hard example + neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)] + neg_score_sorted = paddle.sort(-neg_score) + threshold = -neg_score_sorted[neg_num - 1] + + selected_mask = ((score >= threshold) | + (gt_text > 0.5)) & (training_mask > 0.5) + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + +def ohem_batch(scores, gt_texts, training_masks): + selected_masks = [] + for i in range(scores.shape[0]): + selected_masks.append( + ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[ + i, :, :])) + + selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32") + return selected_masks + + +def iou_single(a, b, mask, n_class): + EPS = 1e-6 + valid = mask == 1 + a = a[valid] + b = b[valid] + miou = [] + + # iou of each class + for i in range(n_class): + inter = paddle.cast(((a == i) & (b == i)), "float32") + union = paddle.cast(((a == i) | (b == i)), "float32") + + miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS)) + miou = sum(miou) / len(miou) + return miou + + +def iou(a, b, mask, n_class=2, reduce=True): + batch_size = a.shape[0] + + a = a.reshape((batch_size, -1)) + b = b.reshape((batch_size, -1)) + mask = mask.reshape((batch_size, -1)) + + iou = paddle.zeros((batch_size, ), dtype="float32") + for i in range(batch_size): + iou[i] = iou_single(a[i], b[i], mask[i], n_class) + + if reduce: + iou = paddle.mean(iou) + return iou + + +class DiceLoss(nn.Layer): + def __init__(self, loss_weight=1.0): + super(DiceLoss, self).__init__() + self.loss_weight = loss_weight + + def forward(self, input, target, mask, reduce=True): + batch_size = input.shape[0] + input = F.sigmoid(input) # scale to 0-1 + + input = input.reshape((batch_size, -1)) + target = paddle.cast(target.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + + input = input * mask + target = target * mask + + a = paddle.sum(input * target, axis=1) + b = paddle.sum(input * input, axis=1) + 0.001 + c = paddle.sum(target * target, axis=1) + 0.001 + d = (2 * a) / (b + c) + loss = 1 - d + + loss = self.loss_weight * loss + + if reduce: + loss = paddle.mean(loss) + + return loss + + +class SmoothL1Loss(nn.Layer): + def __init__(self, beta=1.0, loss_weight=1.0): + super(SmoothL1Loss, self).__init__() + self.beta = beta + self.loss_weight = loss_weight + + np_coord = np.zeros(shape=[640, 640, 2], dtype=np.int64) + for i in range(640): + for j in range(640): + np_coord[i, j, 0] = j + np_coord[i, j, 1] = i + np_coord = np_coord.reshape((-1, 2)) + + self.coord = self.create_parameter( + shape=[640 * 640, 2], + dtype="int32", # NOTE: not support "int64" before paddle 2.3.1 + default_initializer=nn.initializer.Assign(value=np_coord)) + self.coord.stop_gradient = True + + def forward_single(self, input, target, mask, beta=1.0, eps=1e-6): + batch_size = input.shape[0] + + diff = paddle.abs(input - target) * mask.unsqueeze(1) + loss = paddle.where(diff < beta, 0.5 * diff * diff / beta, + diff - 0.5 * beta) + loss = paddle.cast(loss.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + loss = paddle.sum(loss, axis=-1) + loss = loss / (mask.sum(axis=-1) + eps) + + return loss + + def select_single(self, distance, gt_instance, gt_kernel_instance, + training_mask): + + with paddle.no_grad(): + # paddle 2.3.1, paddle.slice not support: + # distance[:, self.coord[:, 1], self.coord[:, 0]] + select_distance_list = [] + for i in range(2): + tmp1 = distance[i, :] + tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]] + select_distance_list.append(tmp2.unsqueeze(0)) + select_distance = paddle.concat(select_distance_list, axis=0) + + off_points = paddle.cast( + self.coord, "float32") + 10 * select_distance.transpose((1, 0)) + + off_points = paddle.cast(off_points, "int64") + off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1) + + selected_mask = ( + gt_instance[self.coord[:, 1], self.coord[:, 0]] != + gt_kernel_instance[off_points[:, 1], off_points[:, 0]]) + selected_mask = paddle.cast( + selected_mask.reshape((1, -1, distance.shape[-1])), "int64") + selected_training_mask = selected_mask * training_mask + + return selected_training_mask + + def forward(self, + distances, + gt_instances, + gt_kernel_instances, + training_masks, + gt_distances, + reduce=True): + + selected_training_masks = [] + for i in range(distances.shape[0]): + selected_training_masks.append( + self.select_single(distances[i, :, :, :], gt_instances[i, :, :], + gt_kernel_instances[i, :, :], training_masks[ + i, :, :])) + selected_training_masks = paddle.cast( + paddle.concat(selected_training_masks, 0), "float32") + + loss = self.forward_single(distances, gt_distances, + selected_training_masks, self.beta) + loss = self.loss_weight * loss + + with paddle.no_grad(): + batch_size = distances.shape[0] + false_num = selected_training_masks.reshape((batch_size, -1)) + false_num = false_num.sum(axis=-1) + total_num = paddle.cast( + training_masks.reshape((batch_size, -1)), "float32") + total_num = total_num.sum(axis=-1) + iou_text = (total_num - false_num) / (total_num + 1e-6) + + if reduce: + loss = paddle.mean(loss) + + return loss, iou_text + + +class CTLoss(nn.Layer): + def __init__(self): + super(CTLoss, self).__init__() + self.kernel_loss = DiceLoss() + self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05) + + def forward(self, preds, batch): + imgs = batch[0] + out = preds['maps'] + gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[ + 1:] + + kernels = out[:, 0, :, :] + distances = out[:, 1:, :, :] + + # kernel loss + selected_masks = ohem_batch(kernels, gt_kernels, training_masks) + + loss_kernel = self.kernel_loss( + kernels, gt_kernels, selected_masks, reduce=False) + + iou_kernel = iou(paddle.cast((kernels > 0), "int64"), + gt_kernels, + training_masks, + reduce=False) + losses = dict(loss_kernels=loss_kernel, ) + + # loc loss + loss_loc, iou_text = self.loc_loss( + distances, + gt_instances, + gt_kernel_instances, + training_mask_distances, + gt_distances, + reduce=False) + losses.update(dict(loss_loc=loss_loc, )) + + loss_all = loss_kernel + loss_loc + losses = {'loss': loss_all} + + return losses diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 853647c0..a39d0a46 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -31,12 +31,14 @@ from .kie_metric import KIEMetric from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_re_metric import VQAReTokenMetric from .sr_metric import SRMetric +from .ct_metric import CTMetric + def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric', 'SRMetric' + 'VQAReTokenMetric', 'SRMetric', 'CTMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/ct_metric.py b/ppocr/metrics/ct_metric.py new file mode 100644 index 00000000..a7634230 --- /dev/null +++ b/ppocr/metrics/ct_metric.py @@ -0,0 +1,52 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from scipy import io +import numpy as np + +from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C + + +class CTMetric(object): + def __init__(self, main_indicator, delimiter='\t', **kwargs): + self.delimiter = delimiter + self.main_indicator = main_indicator + self.reset() + + def reset(self): + self.results = [] # clear results + + def __call__(self, preds, batch, **kwargs): + # NOTE: only support bs=1 now, as the label length of different sample is Unequal + assert len( + preds) == 1, "CentripetalText test now only suuport batch_size=1." + label = batch[2] + text = batch[3] + pred = preds[0]['points'] + result = get_score_C(label, text, pred) + + self.results.append(result) + + def get_metric(self): + """ + Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')' + """ + metrics = combine_results(self.results, rec_flag=False) + self.reset() + return metrics diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 0feda6c6..751757e5 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,6 +23,7 @@ def build_head(config): from .det_pse_head import PSEHead from .det_fce_head import FCEHead from .e2e_pg_head import PGHead + from .det_ct_head import CT_Head # rec head from .rec_ctc_head import CTCHead @@ -52,7 +53,7 @@ def build_head(config): 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', - 'VLHead', 'SLAHead', 'RobustScannerHead' + 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head' ] #table head diff --git a/ppocr/modeling/heads/det_ct_head.py b/ppocr/modeling/heads/det_ct_head.py new file mode 100644 index 00000000..08e6719e --- /dev/null +++ b/ppocr/modeling/heads/det_ct_head.py @@ -0,0 +1,69 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + +import math +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +ones_ = Constant(value=1.) +zeros_ = Constant(value=0.) + + +class CT_Head(nn.Layer): + def __init__(self, + in_channels, + hidden_dim, + num_classes, + loss_kernel=None, + loss_loc=None): + super(CT_Head, self).__init__() + self.conv1 = nn.Conv2D( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(hidden_dim) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D( + hidden_dim, num_classes, kernel_size=1, stride=1, padding=0) + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. / n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f, targets=None): + out = self.conv1(f) + out = self.relu1(self.bn1(out)) + out = self.conv2(out) + + if self.training: + out = self._upsample(out, scale=4) + return {'maps': out} + else: + score = F.sigmoid(out[:, 0, :, :]) + return {'maps': out, 'score': score} diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index e3ae2d6e..c7e8dd06 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -26,13 +26,15 @@ def build_neck(config): from .fce_fpn import FCEFPN from .pren_fpn import PRENFPN from .csp_pan import CSPPAN + from .ct_fpn import CTFPN support_dict = [ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', - 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN' + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN' ] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( support_dict)) + module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/necks/ct_fpn.py b/ppocr/modeling/necks/ct_fpn.py new file mode 100644 index 00000000..ee4d25e9 --- /dev/null +++ b/ppocr/modeling/necks/ct_fpn.py @@ -0,0 +1,185 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr +import os +import sys + +import math +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +ones_ = Constant(value=1.) +zeros_ = Constant(value=0.) + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..'))) + + +class Conv_BN_ReLU(nn.Layer): + def __init__(self, + in_planes, + out_planes, + kernel_size=1, + stride=1, + padding=0): + super(Conv_BN_ReLU, self).__init__() + self.conv = nn.Conv2D( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias_attr=False) + self.bn = nn.BatchNorm2D(out_planes) + self.relu = nn.ReLU() + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. / n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class FPEM(nn.Layer): + def __init__(self, in_channels, out_channels): + super(FPEM, self).__init__() + planes = out_channels + self.dwconv3_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv1_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv3_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv4_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) + + def _upsample_add(self, x, y): + return F.upsample(x, scale_factor=2, mode='bilinear') + y + + def forward(self, f1, f2, f3, f4): + # up-down + f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) + f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2))) + f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1))) + + # down-up + f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1))) + f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2))) + f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3))) + + return f1, f2, f3, f4 + + +class CTFPN(nn.Layer): + def __init__(self, in_channels, out_channel=128): + super(CTFPN, self).__init__() + self.out_channels = out_channel * 4 + + self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) + self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) + self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) + self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) + + self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f): + # # reduce channel + f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160 + f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80 + f3 = self.reduce_layer3(f[2]) # N, 256, 40, 40 --> N, 128, 40, 40 + f4 = self.reduce_layer4(f[3]) # N, 512, 20, 20 --> N, 128, 20, 20 + + # FPEM + f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4) + f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1) + + # FFM + f1 = f1_1 + f1_2 + f2 = f2_1 + f2_2 + f3 = f3_1 + f3_2 + f4 = f4_1 + f4_2 + + f2 = self._upsample(f2, scale=2) + f3 = self._upsample(f3, scale=4) + f4 = self._upsample(f4, scale=8) + ff = paddle.concat((f1, f2, f3, f4), 1) # N,512, 160,160 + return ff diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 8f41a005..35b7a680 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess from .table_postprocess import TableMasterLabelDecode, TableLabelDecode from .picodet_postprocess import PicoDetPostProcess +from .ct_postprocess import CTPostProcess def build_post_process(config, global_config=None): @@ -48,7 +49,7 @@ def build_post_process(config, global_config=None): 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode', 'DistillationSerPostProcess', 'DistillationRePostProcess', - 'VLLabelDecode', 'PicoDetPostProcess' + 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/ct_postprocess.py b/ppocr/postprocess/ct_postprocess.py new file mode 100755 index 00000000..3ab90be2 --- /dev/null +++ b/ppocr/postprocess/ct_postprocess.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refered from: +https://github.com/shengtao96/CentripetalText/blob/main/test.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as osp +import numpy as np +import cv2 +import paddle +import pyclipper + + +class CTPostProcess(object): + """ + The post process for Centripetal Text (CT). + """ + + def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs): + self.min_score = min_score + self.min_area = min_area + self.box_type = box_type + + self.coord = np.zeros((2, 300, 300), dtype=np.int32) + for i in range(300): + for j in range(300): + self.coord[0, i, j] = j + self.coord[1, i, j] = i + + def __call__(self, preds, batch): + outs = preds['maps'] + out_scores = preds['score'] + + if isinstance(outs, paddle.Tensor): + outs = outs.numpy() + if isinstance(out_scores, paddle.Tensor): + out_scores = out_scores.numpy() + + batch_size = outs.shape[0] + boxes_batch = [] + for idx in range(batch_size): + bboxes = [] + scores = [] + + img_shape = batch[idx] + + org_img_size = img_shape[:3] + img_shape = img_shape[3:] + img_size = img_shape[:2] + + out = np.expand_dims(outs[idx], axis=0) + outputs = dict() + + score = np.expand_dims(out_scores[idx], axis=0) + + kernel = out[:, 0, :, :] > 0.2 + loc = out[:, 1:, :, :].astype("float32") + + score = score[0].astype(np.float32) + kernel = kernel[0].astype(np.uint8) + loc = loc[0].astype(np.float32) + + label_num, label_kernel = cv2.connectedComponents( + kernel, connectivity=4) + + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum( + ) < 10: # pixel number less than 10, treated as background + label_kernel[ind] = 0 + + label = np.zeros_like(label_kernel) + h, w = label_kernel.shape + pixels = self.coord[:, :h, :w].reshape(2, -1) + points = pixels.transpose([1, 0]).astype(np.float32) + + off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T + ).astype(np.int32) + off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1) + off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1) + + label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1], + off_points[:, 0]] + label[label_kernel > 0] = label_kernel[label_kernel > 0] + + score_pocket = [0.0] + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum() == 0: + score_pocket.append(0.0) + continue + score_i = np.mean(score[ind]) + score_pocket.append(score_i) + + label_num = np.max(label) + 1 + label = cv2.resize( + label, (img_size[1], img_size[0]), + interpolation=cv2.INTER_NEAREST) + + scale = (float(org_img_size[1]) / float(img_size[1]), + float(org_img_size[0]) / float(img_size[0])) + + for i in range(1, label_num): + ind = (label == i) + points = np.array(np.where(ind)).transpose((1, 0)) + + if points.shape[0] < self.min_area: + continue + + score_i = score_pocket[i] + if score_i < self.min_score: + continue + + if self.box_type == 'rect': + rect = cv2.minAreaRect(points[:, ::-1]) + bbox = cv2.boxPoints(rect) * scale + z = bbox.mean(0) + bbox = z + (bbox - z) * 0.85 + elif self.box_type == 'poly': + binary = np.zeros(label.shape, dtype='uint8') + binary[ind] = 1 + try: + _, contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + except BaseException: + contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + bbox = contours[0] * scale + + bbox = bbox.astype('int32') + bboxes.append(bbox.reshape(-1, 2)) + scores.append(score_i) + + boxes_batch.append({'points': bboxes}) + + return boxes_batch diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 45567a7d..6ce56eda 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import numpy as np import scipy.io as io +import Polygon as plg from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area @@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict): return single_data -def combine_results(all_data): +def get_score_C(gt_label, text, pred_bboxes): + """ + get score for CentripetalText (CT) prediction. + """ + + def gt_reading_mod(gt_label, text): + """This helper reads groundtruths from mat files""" + groundtruths = [] + nbox = len(gt_label) + for i in range(nbox): + label = {"transcription": text[i][0], "points": gt_label[i].numpy()} + groundtruths.append(label) + + return groundtruths + + def get_union(pD, pG): + areaA = pD.area() + areaB = pG.area() + return areaA + areaB - get_intersection(pD, pG) + + def get_intersection(pD, pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def detection_filtering(detections, groundtruths, threshold=0.5): + for gt in groundtruths: + point_num = gt['points'].shape[1] // 2 + if gt['transcription'] == '###' and (point_num > 1): + gt_p = np.array(gt['points']).reshape(point_num, + 2).astype('int32') + gt_p = plg.Polygon(gt_p) + + for det_id, detection in enumerate(detections): + det_y = detection[0::2] + det_x = detection[1::2] + + det_p = np.concatenate((np.array(det_x), np.array(det_y))) + det_p = det_p.reshape(2, -1).transpose() + det_p = plg.Polygon(det_p) + + try: + det_gt_iou = get_intersection(det_p, + gt_p) / det_p.area() + except: + print(det_x, det_y, gt_p) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_p, gt_p): + """ + sigma = inter_area / gt_area + """ + if gt_p.area() == 0.: + return 0 + return get_intersection(det_p, gt_p) / gt_p.area() + + def tau_calculation(det_p, gt_p): + """ + tau = inter_area / det_area + """ + if det_p.area() == 0.: + return 0 + return get_intersection(det_p, gt_p) / det_p.area() + + detections = [] + + for item in pred_bboxes: + detections.append(item[:, ::-1].reshape(-1)) + + groundtruths = gt_reading_mod(gt_label, text) + + detections = detection_filtering( + detections, groundtruths) # filters detections overlapping with DC area + + for idx in range(len(groundtruths) - 1, -1, -1): + #NOTE: source code use 'orin' to indicate '#', here we use 'anno', + # which may cause slight drop in fscore, about 0.12 + if groundtruths[idx]['transcription'] == '###': + groundtruths.pop(idx) + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + point_num = gt['points'].shape[1] // 2 + + gt_p = np.array(gt['points']).reshape(point_num, + 2).astype('int32') + gt_p = plg.Polygon(gt_p) + + det_y = detection[0::2] + det_x = detection[1::2] + + det_p = np.concatenate((np.array(det_x), np.array(det_y))) + + det_p = det_p.reshape(2, -1).transpose() + det_p = plg.Polygon(det_p) + + local_sigma_table[gt_id, det_id] = sigma_calculation(det_p, + gt_p) + local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p) + + data = {} + data['sigma'] = local_sigma_table + data['global_tau'] = local_tau_table + data['global_pred_str'] = '' + data['global_gt_str'] = '' + return data + + +def combine_results(all_data, rec_flag=True): tr = 0.7 tp = 0.6 fsc_k = 0.8 @@ -278,6 +397,7 @@ def combine_results(all_data): global_tau = [] global_pred_str = [] global_gt_str = [] + for data in all_data: global_sigma.append(data['sigma']) global_tau.append(data['global_tau']) @@ -294,7 +414,7 @@ def combine_results(all_data): def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): gt_matching_qualified_sigma_candidates = np.where( @@ -328,14 +448,15 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ - 0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0] + .tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end det_flag[0, matched_det_id] = 1 return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num @@ -343,7 +464,7 @@ def combine_results(all_data): def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): # skip the following if the groundtruth was matched @@ -374,28 +495,30 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end global_accumulative_recall = global_accumulative_recall + fsc_k @@ -409,7 +532,7 @@ def combine_results(all_data): def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for det_id in range(num_det): # skip the following if the detection was matched @@ -440,6 +563,30 @@ def combine_results(all_data): gt_flag[0, qualified_sigma_candidates] = 1 det_flag[0, det_id] = 1 # recg start + if rec_flag: + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[ + 0].tolist()[idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower( + ): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + if rec_flag: pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): @@ -454,27 +601,7 @@ def combine_results(all_data): else: if pred_str_cur.lower() == gt_str_cur.lower(): hit_str_num += 1 - break - # recg end - elif (np.sum(local_tau_table[qualified_sigma_candidates, - det_id]) >= tp): - det_flag[0, det_id] = 1 - gt_flag[0, qualified_sigma_candidates] = 1 - # recg start - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if ele_gt_id not in global_gt_str[idy]: - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break + break # recg end global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k @@ -504,7 +631,7 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for one-to-many case########## @@ -512,14 +639,14 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for many-to-one case########## local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num try: diff --git a/requirements.txt b/requirements.txt index 2c0741a0..43cd8c1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ lxml premailer openpyxl attrdict +Polygon3 diff --git a/test_tipc/configs/det_r18_ct/train_infer_python.txt b/test_tipc/configs/det_r18_ct/train_infer_python.txt new file mode 100644 index 00000000..5933fdbe --- /dev/null +++ b/test_tipc/configs/det_r18_ct/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det_r18_ct +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./train_data/total_text/test/rgb/ +null:null +## +trainer:norm_train +norm_train:tools/train.py -c configs/det/det_r18_vd_ct.yml -o Global.print_batch_step=1 Train.loader.shuffle=false +quant_export:null +fpgm_export:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c configs/det/det_r18_vd_ct.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/det_r18_vd_ct/best_accuracy +infer_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o +infer_quant:False +inference:tools/infer/predict_det.py +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +--save_log_path:null +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] \ No newline at end of file diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 5d50a5ad..657d4964 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -264,6 +264,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then cd ./train_data/ && tar xf XFUND.tar cd ../ fi + if [ ${model_name} == "det_r18_ct" ]; then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate + cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../ + fi elif [ ${MODE} = "whole_train_whole_infer" ];then wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 9f5c480d..00fa2e9b 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -127,6 +127,9 @@ class TextDetector(object): postprocess_params["beta"] = args.beta postprocess_params["fourier_degree"] = args.fourier_degree postprocess_params["box_type"] = args.det_fce_box_type + elif self.det_algorithm == "CT": + pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}} + postprocess_params['name'] = 'CTPostProcess' else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -253,6 +256,9 @@ class TextDetector(object): elif self.det_algorithm == 'FCE': for i, output in enumerate(outputs): preds['level_{}'.format(i)] = output + elif self.det_algorithm == "CT": + preds['maps'] = outputs[0] + preds['score'] = outputs[1] else: raise NotImplementedError @@ -260,7 +266,7 @@ class TextDetector(object): post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( - self.det_algorithm in ["PSE", "FCE"] and + self.det_algorithm in ["PSE", "FCE", "CT"] and self.postprocess_op.box_type == 'poly'): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: diff --git a/tools/program.py b/tools/program.py index c91e66fd..9117d51b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -625,7 +625,7 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner' + 'Gestalt', 'SLANet', 'RobustScanner', 'CT' ] if use_xpu: diff --git a/tools/train.py b/tools/train.py index d0f20018..970a5262 100755 --- a/tools/train.py +++ b/tools/train.py @@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) + use_sync_bn = config["Global"].get("use_sync_bn", False) if use_sync_bn: model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -138,7 +139,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) - + logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( @@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer): use_amp = config["Global"].get("use_amp", False) amp_level = config["Global"].get("amp_level", 'O2') - amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) + amp_custom_black_list = config['Global'].get('amp_custom_black_list', []) if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -161,20 +162,24 @@ def main(config, device, logger, vdl_writer): use_dynamic_loss_scaling=use_dynamic_loss_scaling) if amp_level == "O2": model, optimizer = paddle.amp.decorate( - models=model, optimizers=optimizer, level=amp_level, master_weight=True) + models=model, + optimizers=optimizer, + level=amp_level, + master_weight=True) else: scaler = None # load pretrain model pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"]) - + if config['Global']['distributed']: model = paddle.DataParallel(model) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler, + amp_level, amp_custom_black_list) def test_reader(config, device, logger): diff --git a/train.sh b/train.sh index 4225470c..6fa04ea3 100644 --- a/train.sh +++ b/train.sh @@ -1,2 +1,2 @@ # recommended paddle.__version__ == 2.0.0 -python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml +python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml \ No newline at end of file -- GitLab