diff --git a/docker/hubserving/README.md b/docker/hubserving/README.md new file mode 100644 index 0000000000000000000000000000000000000000..71e2377dcc4f7524384752b95c53f02471353f34 --- /dev/null +++ b/docker/hubserving/README.md @@ -0,0 +1,58 @@ +English | [简体中文](README_cn.md) + +## Introduction +Many users hope to package the PaddleOCR service into a Docker image so that it can be quickly released and used in Docker or Kubernetes (k8s) environments. + +This page provides some standardized code to achieve this goal. You can quickly publish the PaddleOCR project as a callable RESTful API service through the following steps. (At present, only deployment based on the HubServing mode is implemented; deployment based on the PaddleServing mode is planned for the future.) + +## 1. Prerequisites + +You need to install the following basic components first: +a. Docker +b. Graphics driver and CUDA 10.0+(GPU) +c. NVIDIA Container Toolkit (GPU; can be skipped with Docker 19.03+) +d. cuDNN 7.6+(GPU) + +## 2. Build Image +a. Download the PaddleOCR source code +``` +git clone https://github.com/PaddlePaddle/PaddleOCR.git +``` +b. Go to the Dockerfile directory (note: the CPU and GPU versions use different Dockerfiles; the following takes the CPU version as an example, for the GPU version replace the keyword cpu with gpu) +``` +cd docker/cpu +``` +c. Build the image +``` +docker build -t paddleocr:cpu . +``` + +## 3. Start container +a. CPU version +``` +sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu +``` +b. GPU version (based on NVIDIA Container Toolkit) +``` +sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu +``` +c. GPU version (Docker 19.03+) +``` +sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu +``` +d. Check the service status (if you can see the following messages, the service has started successfully: Successfully installed ocr_system and Running on http://0.0.0.0:8866/) +``` +docker logs -f paddle_ocr +``` + +## 4. Test +a. Calculate the Base64 encoding of the image to be recognized (for a quick test, you can use a free online tool such as https://freeonlinetools24.com/base64-image/) +b. Post a service request (see sample_request.txt for a sample request) + +``` +curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Base64 encoding of the input image (remove the prefix 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system +``` +c. Get the response (if the call succeeds, a result like the following will be returned) +``` +{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"} +``` diff --git a/docker/hubserving/README_cn.md b/docker/hubserving/README_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..9b9e5f50f5b22f3a2125a656112a20542010ac68 --- /dev/null +++ b/docker/hubserving/README_cn.md @@ -0,0 +1,57 @@ +[English](README.md) | 简体中文 + +## Docker化部署服务 +在日常项目应用中,相信大家一般都会希望能通过Docker技术,把PaddleOCR服务打包成一个镜像,以便在Docker或k8s环境里,快速发布上线使用。 + +本文将提供一些标准化的代码来实现这样的目标。大家通过如下步骤可以把PaddleOCR项目快速发布成可调用的Restful API服务。(目前暂时先实现了基于HubServing模式的部署,后续作者计划增加PaddleServing模式的部署) + +## 1.实施前提准备 + +需要先完成如下基本组件的安装: +a. Docker环境 +b. 显卡驱动和CUDA 10.0+(GPU) +c. NVIDIA Container Toolkit(GPU,Docker 19.03以上版本可以跳过此步) +d.
cuDNN 7.6+(GPU) + +## 2.制作镜像 +a.下载PaddleOCR项目代码 +``` +git clone https://github.com/PaddlePaddle/PaddleOCR.git +``` +b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可) +``` +cd docker/cpu +``` +c.生成镜像 +``` +docker build -t paddleocr:cpu . +``` + +## 3.启动Docker容器 +a. CPU 版本 +``` +sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu +``` +b. GPU 版本 (通过NVIDIA Container Toolkit) +``` +sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu +``` +c. GPU 版本 (Docker 19.03以上版本,可以直接用如下命令) +``` +sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu +``` +d. 检查服务运行情况(出现:Successfully installed ocr_system和Running on http://0.0.0.0:8866/等信息,表示运行成功) +``` +docker logs -f paddle_ocr +``` + +## 4.测试服务 +a. 计算待识别图片的Base64编码(如果只是测试一下效果,可以通过免费的在线工具实现,如:http://tool.chinaz.com/tools/imgtobase/) +b. 发送服务请求(可参见sample_request.txt中的值) +``` +curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"填入图片Base64编码(需要删除'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system +``` +c. 返回结果(如果调用成功,会返回如下结果) +``` +{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"} +``` diff --git a/docker/hubserving/readme.md b/docker/hubserving/readme.md index 109e6aa6a536c146095b8f46516b1c895dc08337..71e2377dcc4f7524384752b95c53f02471353f34 100644 --- a/docker/hubserving/readme.md +++ b/docker/hubserving/readme.md @@ -1,55 +1,58 @@ -# Docker化部署服务 -在日常项目应用中,相信大家一般都会希望能通过Docker技术,把PaddleOCR服务打包成一个镜像,以便在Docker或k8s环境里,快速发布上线使用。 +English | [简体中文](README_cn.md) -本文将提供一些标准化的代码来实现这样的目标。大家通过如下步骤可以把PaddleOCR项目快速发布成可调用的Restful API服务。(目前暂时先实现了基于HubServing模式的部署,后续作者计划增加PaddleServing模式的部署) +## Introduction +Many users hope to package the PaddleOCR service into a Docker image so that it can be quickly released and used in Docker or Kubernetes (k8s) environments. -## 1.实施前提准备 +This page provides some standardized code to achieve this goal. You can quickly publish the PaddleOCR project as a callable RESTful API service through the following steps. (At present, only deployment based on the HubServing mode is implemented; deployment based on the PaddleServing mode is planned for the future.) -需要先完成如下基本组件的安装: -a. Docker环境 -b. 显卡驱动和CUDA 10.0+(GPU) -c. NVIDIA Container Toolkit(GPU,Docker 19.03以上版本可以跳过此步) +## 1. Prerequisites + +You need to install the following basic components first: +a. Docker +b. Graphics driver and CUDA 10.0+(GPU) +c. NVIDIA Container Toolkit (GPU; can be skipped with Docker 19.03+) d. cuDNN 7.6+(GPU) -## 2.制作镜像 -a.下载PaddleOCR项目代码 +## 2. Build Image +a. Download the PaddleOCR source code ``` git clone https://github.com/PaddlePaddle/PaddleOCR.git ``` -b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可) +b. Go to the Dockerfile directory (note: the CPU and GPU versions use different Dockerfiles; the following takes the CPU version as an example, for the GPU version replace the keyword cpu with gpu) ``` cd docker/cpu ``` -c.生成镜像 +c. Build the image ``` docker build -t paddleocr:cpu . ``` -## 3.启动Docker容器 -a. CPU 版本 +## 3. Start container +a. CPU version ``` sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu ``` -b. GPU 版本 (通过NVIDIA Container Toolkit) +b. GPU version (based on NVIDIA Container Toolkit) ``` sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu ``` -c. GPU 版本 (Docker 19.03以上版本,可以直接用如下命令) +c.
GPU version (Docker 19.03+) ``` sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu ``` -d. 检查服务运行情况(出现:Successfully installed ocr_system和Running on http://0.0.0.0:8866/等信息,表示运行成功) +d. Check the service status (if you can see the following messages, the service has started successfully: Successfully installed ocr_system and Running on http://0.0.0.0:8866/) ``` docker logs -f paddle_ocr ``` -## 4.测试服务 -a. 计算待识别图片的Base64编码(如果只是测试一下效果,可以通过免费的在线工具实现,如:http://tool.chinaz.com/tools/imgtobase/) -b. 发送服务请求(可参见sample_request.txt中的值) +## 4. Test +a. Calculate the Base64 encoding of the image to be recognized (for a quick test, you can use a free online tool such as https://freeonlinetools24.com/base64-image/) +b. Post a service request (see sample_request.txt for a sample request) + ``` -curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"填入图片Base64编码(需要删除'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system +curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Base64 encoding of the input image (remove the prefix 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system ``` -c. 返回结果(如果调用成功,会返回如下结果) +c. Get the response (if the call succeeds, a result like the following will be returned) ``` {"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"} ``` diff --git a/paddleocr.py b/paddleocr.py index 65bca7ae243e15e4788b5b637be65d57cf9504e5..d3d73cb1b92cb2228fafb4e0efa36ab13207a4b3 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -129,6 +129,7 @@ def parse_args(): parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True) + parser.add_argument("--use_zero_copy_run", type=bool, default=False) return parser.parse_args() @@ -209,4 +210,4 @@ def main(): print(img_path) result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) for line in result: - print(line) + print(line) \ No newline at end of file diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index 67cbf9b53ad7b877be8985d76627cdf97d49f423..84f325b9b880d6289a4d60f7ebff39d962fdb5a1 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -257,6 +257,7 @@ class SimpleReader(object): norm_img = process_image_srn( img=img, image_shape=self.image_shape, + char_ops=self.char_ops, num_heads=self.num_heads, max_text_length=self.max_text_length) else: diff --git a/ppocr/modeling/heads/self_attention/model.py b/ppocr/modeling/heads/self_attention/model.py index 8ac1458b7dca3dedc368a16fa00f52a9aa4f4f93..8bf34e4ac6a2c3c33d2a46b1f4f9dbfaf8db8f57 100644 --- a/ppocr/modeling/heads/self_attention/model.py +++ b/ppocr/modeling/heads/self_attention/model.py @@ -4,8 +4,10 @@ import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -# Set seed for CE -dropout_seed = None +encoder_data_input_fields = ( + "src_word", + "src_pos", + "src_slf_attn_bias", ) def wrap_layer_with_block(layer, block_idx): @@ -45,25 +47,6 @@ def wrap_layer_with_block(layer, block_idx): return layer_wrapper -def position_encoding_init(n_position, d_pos_vec): - """ - Generate the initial values for the sinusoid position encoding table.
- """ - channels = d_pos_vec - position = np.arange(n_position) - num_timescales = channels // 2 - log_timescale_increment = (np.log(float(1e4) / float(1)) / - (num_timescales - 1)) - inv_timescales = np.exp(np.arange( - num_timescales)) * -log_timescale_increment - scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, - 0) - signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) - signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant') - position_enc = signal - return position_enc.astype("float32") - - def multi_head_attention(queries, keys, values, @@ -200,10 +183,7 @@ def multi_head_attention(queries, weights = layers.softmax(product) if dropout_rate: weights = layers.dropout( - weights, - dropout_prob=dropout_rate, - seed=dropout_seed, - is_test=False) + weights, dropout_prob=dropout_rate, seed=None, is_test=False) out = layers.matmul(weights, v) return out @@ -235,7 +215,7 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate): act="relu") if dropout_rate: hidden = layers.dropout( - hidden, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False) + hidden, dropout_prob=dropout_rate, seed=None, is_test=False) out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2) return out @@ -259,10 +239,7 @@ def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): elif cmd == "d": # add dropout if dropout_rate: out = layers.dropout( - out, - dropout_prob=dropout_rate, - seed=dropout_seed, - is_test=False) + out, dropout_prob=dropout_rate, seed=None, is_test=False) return out @@ -271,7 +248,7 @@ post_process_layer = pre_post_process_layer def prepare_encoder( - src_word, #[b,t,c] + src_word, # [b,t,c] src_pos, src_vocab_size, src_emb_dim, @@ -286,9 +263,8 @@ def prepare_encoder( This module is used at the bottom of the encoder stacks. 
""" - src_word_emb = src_word #layers.concat(res,axis=1) + src_word_emb = src_word src_word_emb = layers.cast(src_word_emb, 'float32') - # print("src_word_emb",src_word_emb) src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_pos_enc = layers.embedding( @@ -299,7 +275,7 @@ def prepare_encoder( src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc return layers.dropout( - enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + enc_input, dropout_prob=dropout_rate, seed=None, is_test=False) if dropout_rate else enc_input @@ -324,7 +300,7 @@ def prepare_decoder(src_word, param_attr=fluid.ParamAttr( name=word_emb_param_name, initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) - # print("target_word_emb",src_word_emb) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) src_pos_enc = layers.embedding( src_pos, @@ -334,16 +310,10 @@ def prepare_decoder(src_word, src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc return layers.dropout( - enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + enc_input, dropout_prob=dropout_rate, seed=None, is_test=False) if dropout_rate else enc_input -# prepare_encoder = partial( -# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0]) -# prepare_decoder = partial( -# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1]) - - def encoder_layer(enc_input, attn_bias, n_head, @@ -412,234 +382,6 @@ def encoder(enc_input, return enc_output -def decoder_layer(dec_input, - enc_output, - slf_attn_bias, - dec_enc_attn_bias, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - cache=None, - gather_idx=None): - """ The layer to be stacked in decoder part. - The structure of this module is similar to that in the encoder part except - a multi-head attention is added to implement encoder-decoder attention. - """ - slf_attn_output = multi_head_attention( - pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout), - None, - None, - slf_attn_bias, - d_key, - d_value, - d_model, - n_head, - attention_dropout, - cache=cache, - gather_idx=gather_idx) - slf_attn_output = post_process_layer( - dec_input, - slf_attn_output, - postprocess_cmd, - prepostprocess_dropout, ) - enc_attn_output = multi_head_attention( - pre_process_layer(slf_attn_output, preprocess_cmd, - prepostprocess_dropout), - enc_output, - enc_output, - dec_enc_attn_bias, - d_key, - d_value, - d_model, - n_head, - attention_dropout, - cache=cache, - gather_idx=gather_idx, - static_kv=True) - enc_attn_output = post_process_layer( - slf_attn_output, - enc_attn_output, - postprocess_cmd, - prepostprocess_dropout, ) - ffd_output = positionwise_feed_forward( - pre_process_layer(enc_attn_output, preprocess_cmd, - prepostprocess_dropout), - d_inner_hid, - d_model, - relu_dropout, ) - dec_output = post_process_layer( - enc_attn_output, - ffd_output, - postprocess_cmd, - prepostprocess_dropout, ) - return dec_output - - -def decoder(dec_input, - enc_output, - dec_slf_attn_bias, - dec_enc_attn_bias, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - caches=None, - gather_idx=None): - """ - The decoder is composed of a stack of identical decoder_layer layers. 
- """ - for i in range(n_layer): - dec_output = decoder_layer( - dec_input, - enc_output, - dec_slf_attn_bias, - dec_enc_attn_bias, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - cache=None if caches is None else caches[i], - gather_idx=gather_idx) - dec_input = dec_output - dec_output = pre_process_layer(dec_output, preprocess_cmd, - prepostprocess_dropout) - return dec_output - - -def make_all_inputs(input_fields): - """ - Define the input data layers for the transformer model. - """ - inputs = [] - for input_field in input_fields: - input_var = layers.data( - name=input_field, - shape=input_descs[input_field][0], - dtype=input_descs[input_field][1], - lod_level=input_descs[input_field][2] - if len(input_descs[input_field]) == 3 else 0, - append_batch_size=False) - inputs.append(input_var) - return inputs - - -def make_all_py_reader_inputs(input_fields, is_test=False): - reader = layers.py_reader( - capacity=20, - name="test_reader" if is_test else "train_reader", - shapes=[input_descs[input_field][0] for input_field in input_fields], - dtypes=[input_descs[input_field][1] for input_field in input_fields], - lod_levels=[ - input_descs[input_field][2] - if len(input_descs[input_field]) == 3 else 0 - for input_field in input_fields - ]) - return layers.read_file(reader), reader - - -def transformer(src_vocab_size, - trg_vocab_size, - max_length, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - label_smooth_eps, - bos_idx=0, - use_py_reader=False, - is_test=False): - if weight_sharing: - assert src_vocab_size == trg_vocab_size, ( - "Vocabularies in source and target should be same for weight sharing." - ) - - data_input_names = encoder_data_input_fields + \ - decoder_data_input_fields[:-1] + label_data_input_fields - - if use_py_reader: - all_inputs, reader = make_all_py_reader_inputs(data_input_names, - is_test) - else: - all_inputs = make_all_inputs(data_input_names) - # print("all inputs",all_inputs) - enc_inputs_len = len(encoder_data_input_fields) - dec_inputs_len = len(decoder_data_input_fields[:-1]) - enc_inputs = all_inputs[0:enc_inputs_len] - dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len] - label = all_inputs[-2] - weights = all_inputs[-1] - - enc_output = wrap_encoder( - src_vocab_size, 64, n_layer, n_head, d_key, d_value, d_model, - d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout, - preprocess_cmd, postprocess_cmd, weight_sharing, enc_inputs) - - predict = wrap_decoder( - trg_vocab_size, - max_length, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - dec_inputs, - enc_output, ) - - # Padding index do not contribute to the total loss. The weights is used to - # cancel padding index in calculating the loss. 
- if label_smooth_eps: - label = layers.label_smooth( - label=layers.one_hot( - input=label, depth=trg_vocab_size), - epsilon=label_smooth_eps) - - cost = layers.softmax_with_cross_entropy( - logits=predict, - label=label, - soft_label=True if label_smooth_eps else False) - weighted_cost = cost * weights - sum_cost = layers.reduce_sum(weighted_cost) - token_num = layers.reduce_sum(weights) - token_num.stop_gradient = True - avg_cost = sum_cost / token_num - return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None - - def wrap_encoder_forFeature(src_vocab_size, max_length, n_layer, @@ -662,44 +404,8 @@ def wrap_encoder_forFeature(src_vocab_size, img """ - if enc_inputs is None: - # This is used to implement independent encoder program in inference. - conv_features, src_pos, src_slf_attn_bias = make_all_inputs( - encoder_data_input_fields) - else: - conv_features, src_pos, src_slf_attn_bias = enc_inputs # - b, t, c = conv_features.shape - #""" - # insert cnn - #""" - #import basemodel - # feat = basemodel.resnet_50(img) - - # mycrnn = basemodel.CRNN() - # feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu) - # b, c, w, h = feat.shape - # src_word = layers.reshape(feat, shape=[-1, c, w * h]) - - #myconv8 = basemodel.conv8() - #feat = myconv8.net(img ) - #b , c, h, w = feat.shape#h=6 - #print(feat) - #layers.Print(feat,message="conv_feat",summarize=10) - - #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") - #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) - #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] - - #feat = layers.transpose(feat, [0,3,1,2]) - #src_word = layers.reshape(feat,[-1,w, c*h]) - #src_word = layers.im2sequence( - # input=feat, - # stride=[1, 1], - # filter_size=[feat.shape[2], 1]) - #layers.Print(src_word,message="src_word",summarize=10) - - # print('feat',feat) - #print("src_word",src_word) + conv_features, src_pos, src_slf_attn_bias = enc_inputs # + b, t, c = conv_features.shape enc_input = prepare_encoder( conv_features, @@ -749,43 +455,9 @@ def wrap_encoder(src_vocab_size, img, src_pos, src_slf_attn_bias = enc_inputs img """ - if enc_inputs is None: - # This is used to implement independent encoder program in inference. 
- src_word, src_pos, src_slf_attn_bias = make_all_inputs( - encoder_data_input_fields) - else: - src_word, src_pos, src_slf_attn_bias = enc_inputs # - #""" - # insert cnn - #""" - #import basemodel - # feat = basemodel.resnet_50(img) - - # mycrnn = basemodel.CRNN() - # feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu) - # b, c, w, h = feat.shape - # src_word = layers.reshape(feat, shape=[-1, c, w * h]) - #myconv8 = basemodel.conv8() - #feat = myconv8.net(img ) - #b , c, h, w = feat.shape#h=6 - #print(feat) - #layers.Print(feat,message="conv_feat",summarize=10) + src_word, src_pos, src_slf_attn_bias = enc_inputs # - #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") - #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) - #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] - - #feat = layers.transpose(feat, [0,3,1,2]) - #src_word = layers.reshape(feat,[-1,w, c*h]) - #src_word = layers.im2sequence( - # input=feat, - # stride=[1, 1], - # filter_size=[feat.shape[2], 1]) - #layers.Print(src_word,message="src_word",summarize=10) - - # print('feat',feat) - #print("src_word",src_word) enc_input = prepare_decoder( src_word, src_pos, @@ -811,248 +483,3 @@ def wrap_encoder(src_vocab_size, preprocess_cmd, postprocess_cmd, ) return enc_output - - -def wrap_decoder(trg_vocab_size, - max_length, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - dec_inputs=None, - enc_output=None, - caches=None, - gather_idx=None, - bos_idx=0): - """ - The wrapper assembles together all needed layers for the decoder. - """ - if dec_inputs is None: - # This is used to implement independent decoder program in inference. - trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \ - make_all_inputs(decoder_data_input_fields) - else: - trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs - - dec_input = prepare_decoder( - trg_word, - trg_pos, - trg_vocab_size, - d_model, - max_length, - prepostprocess_dropout, - bos_idx=bos_idx, - word_emb_param_name="src_word_emb_table" - if weight_sharing else "trg_word_emb_table") - dec_output = decoder( - dec_input, - enc_output, - trg_slf_attn_bias, - trg_src_attn_bias, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - caches=caches, - gather_idx=gather_idx) - return dec_output - # Reshape to 2D tensor to use GEMM instead of BatchedGEMM - dec_output = layers.reshape( - dec_output, shape=[-1, dec_output.shape[-1]], inplace=True) - if weight_sharing: - predict = layers.matmul( - x=dec_output, - y=fluid.default_main_program().global_block().var( - "trg_word_emb_table"), - transpose_y=True) - else: - predict = layers.fc(input=dec_output, - size=trg_vocab_size, - bias_attr=False) - if dec_inputs is None: - # Return probs for independent decoder program. - predict = layers.softmax(predict) - return predict - - -def fast_decode(src_vocab_size, - trg_vocab_size, - max_in_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - beam_size, - max_out_len, - bos_idx, - eos_idx, - use_py_reader=False): - """ - Use beam search to decode. Caches will be used to store states of history - steps which can make the decoding faster. 
- """ - data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields - - if use_py_reader: - all_inputs, reader = make_all_py_reader_inputs(data_input_names) - else: - all_inputs = make_all_inputs(data_input_names) - - enc_inputs_len = len(encoder_data_input_fields) - dec_inputs_len = len(fast_decoder_data_input_fields) - enc_inputs = all_inputs[0:enc_inputs_len] #enc_inputs tensor - dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + - dec_inputs_len] #dec_inputs tensor - - enc_output = wrap_encoder( - src_vocab_size, - 64, ##to do !!!!!???? - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - enc_inputs, - bos_idx=bos_idx) - start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs - - def beam_search(): - max_len = layers.fill_constant( - shape=[1], - dtype=start_tokens.dtype, - value=max_out_len, - force_cpu=True) - step_idx = layers.fill_constant( - shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True) - cond = layers.less_than(x=step_idx, y=max_len) # default force_cpu=True - while_op = layers.While(cond) - # array states will be stored for each step. - ids = layers.array_write( - layers.reshape(start_tokens, (-1, 1)), step_idx) - scores = layers.array_write(init_scores, step_idx) - # cell states will be overwrited at each step. - # caches contains states of history steps in decoder self-attention - # and static encoder output projections in encoder-decoder attention - # to reduce redundant computation. - caches = [ - { - "k": # for self attention - layers.fill_constant_batch_size_like( - input=start_tokens, - shape=[-1, n_head, 0, d_key], - dtype=enc_output.dtype, - value=0), - "v": # for self attention - layers.fill_constant_batch_size_like( - input=start_tokens, - shape=[-1, n_head, 0, d_value], - dtype=enc_output.dtype, - value=0), - "static_k": # for encoder-decoder attention - layers.create_tensor(dtype=enc_output.dtype), - "static_v": # for encoder-decoder attention - layers.create_tensor(dtype=enc_output.dtype) - } for i in range(n_layer) - ] - - with while_op.block(): - pre_ids = layers.array_read(array=ids, i=step_idx) - # Since beam_search_op dosen't enforce pre_ids' shape, we can do - # inplace reshape here which actually change the shape of pre_ids. - pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True) - pre_scores = layers.array_read(array=scores, i=step_idx) - # gather cell states corresponding to selected parent - pre_src_attn_bias = layers.gather( - trg_src_attn_bias, index=parent_idx) - pre_pos = layers.elementwise_mul( - x=layers.fill_constant_batch_size_like( - input=pre_src_attn_bias, # cann't use lod tensor here - value=1, - shape=[-1, 1, 1], - dtype=pre_ids.dtype), - y=step_idx, - axis=0) - logits = wrap_decoder( - trg_vocab_size, - max_in_len, - n_layer, - n_head, - d_key, - d_value, - d_model, - d_inner_hid, - prepostprocess_dropout, - attention_dropout, - relu_dropout, - preprocess_cmd, - postprocess_cmd, - weight_sharing, - dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias), - enc_output=enc_output, - caches=caches, - gather_idx=parent_idx, - bos_idx=bos_idx) - # intra-beam topK - topk_scores, topk_indices = layers.topk( - input=layers.softmax(logits), k=beam_size) - accu_scores = layers.elementwise_add( - x=layers.log(topk_scores), y=pre_scores, axis=0) - # beam_search op uses lod to differentiate branches. 
- accu_scores = layers.lod_reset(accu_scores, pre_ids) - # topK reduction across beams, also contain special handle of - # end beams and end sentences(batch reduction) - selected_ids, selected_scores, gather_idx = layers.beam_search( - pre_ids=pre_ids, - pre_scores=pre_scores, - ids=topk_indices, - scores=accu_scores, - beam_size=beam_size, - end_id=eos_idx, - return_parent_idx=True) - layers.increment(x=step_idx, value=1.0, in_place=True) - # cell states(caches) have been updated in wrap_decoder, - # only need to update beam search states here. - layers.array_write(selected_ids, i=step_idx, array=ids) - layers.array_write(selected_scores, i=step_idx, array=scores) - layers.assign(gather_idx, parent_idx) - layers.assign(pre_src_attn_bias, trg_src_attn_bias) - length_cond = layers.less_than(x=step_idx, y=max_len) - finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) - layers.logical_and(x=length_cond, y=finish_cond, out=cond) - - finished_ids, finished_scores = layers.beam_search_decode( - ids, scores, beam_size=beam_size, end_id=eos_idx) - return finished_ids, finished_scores - - finished_ids, finished_scores = beam_search() - return finished_ids, finished_scores, reader if use_py_reader else None
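Note on testing the hubserving deployment added above: the following is a minimal Python client sketch for section 4 of docker/hubserving/README.md. It is not part of the patch; it assumes the container from section 3 is already listening on localhost:8866, that the `requests` package is installed, and it uses a hypothetical image file name `test.jpg`. It performs the README's two steps (Base64-encode the image, then POST it to /predict/ocr_system) without relying on an online Base64 tool.

```
import base64
import json

import requests

# Step 4a: compute the Base64 encoding of the image to be recognized.
# "test.jpg" is a placeholder; replace it with your own image path.
with open("test.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

# Step 4b: post the request to the service started in section 3.
# The payload format matches the curl example in the README:
# {"images": ["<Base64 string without the 'data:image/jpg;base64,' prefix>"]}
resp = requests.post(
    "http://localhost:8866/predict/ocr_system",
    headers={"Content-Type": "application/json"},
    data=json.dumps({"images": [img_b64]}),
    timeout=60,
)

# Step 4c: each item of the first result list carries "text", "confidence"
# and "text_region", as in the sample response shown in the README.
for item in resp.json()["results"][0]:
    print(item["text"], item["confidence"])
```

The same request can also be issued with the curl command from section 4b; this sketch only replaces the manual Base64 step.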