Commit 74f6a5cb authored by W WenmuZhou

merge paddleocr

@@ -189,7 +189,7 @@ For training and use of PaddleOCR text recognition algorithms, please refer to the tutorial [Model train…
 Scan the QR code below and complete the questionnaire to get the group QR code and OCR training tips.
 <div align="center">
-<img src="./doc/joinus.jpg" width = "200" height = "200" />
+<img src="./doc/joinus.PNG" width = "200" height = "200" />
 </div>
 <a name="许可证书"></a>
......
@@ -56,7 +56,6 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Android…
 - Algorithm introduction
     - [Text Detection Algorithm](#TEXTDETECTIONALGORITHM)
     - [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM)
-    - [END-TO-END OCR Algorithm](#ENDENDOCRALGORITHM)
 - Model training/evaluation
     - [Text Detection](./doc/doc_en/detection_en.md)
     - [Text Recognition](./doc/doc_en/recognition_en.md)
@@ -158,10 +157,6 @@ We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/…
 Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
-<a name="ENDENDOCRALGORITHM"></a>
-## END-TO-END OCR Algorithm
-- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, coming soon)
 ## Visualization
 <a name="UCOCRVIS"></a>
@@ -211,7 +206,7 @@ Please refer to the document for training guide and use of PaddleOCR text recognition…
 Scan the QR code below with your WeChat and complete the questionnaire to get access to the official technical exchange group.
 <div align="center">
-<img src="./doc/joinus.jpg" width = "200" height = "200" />
+<img src="./doc/joinus.PNG" width = "200" height = "200" />
 </div>
 <a name="LICENSE"></a>
......
@@ -140,7 +140,7 @@ PaddleOCR provides multiple data augmentation methods; if you want to add augmentation during training…
 During training, each perturbation is selected with a probability of 50%; for the implementation, see [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) (a minimal sketch follows this hunk).
-*Due to OpenCV compatibility issues, the perturbation operations currently only support GPU*
+*Due to OpenCV compatibility issues, the perturbation operations currently only support Linux*
 - Training
......
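For intuition, here is a minimal sketch of the 50%-probability selection described above; the function name and signature are illustrative, not the actual img_tools.py API:
```
import random

def apply_augmentations(img, augmenters):
    # Each perturbation is applied independently with 50% probability,
    # mirroring the behavior described for img_tools.py.
    for aug in augmenters:
        if random.random() < 0.5:
            img = aug(img)
    return img
```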
@@ -61,6 +61,14 @@ hub install deploy\hubserving\ocr_rec\
 hub install deploy\hubserving\ocr_system\
 ```
+
+#### Install the models
+Before installing a service module, place the trained models in the corresponding folders. The defaults are:
+./inference/ch_det_mv3_db/
+./inference/ch_rec_mv3_crnn/
+These two models can be downloaded from https://github.com/PaddlePaddle/PaddleOCR
+You can switch to your own models by editing ./deploy/hubserving/ocr_system/params.py, as in the sketch below
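For illustration, a minimal sketch of the kind of edit meant here; the actual structure and field names of params.py may differ, so treat them as assumptions:
```
# Sketch of ./deploy/hubserving/ocr_system/params.py (field names assumed).
class Config(object):
    pass

def read_params():
    cfg = Config()
    # Point these at your own exported inference models.
    cfg.det_model_dir = "./inference/ch_det_mv3_db/"
    cfg.rec_model_dir = "./inference/ch_rec_mv3_crnn/"
    return cfg
```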
 ### 3. Start the service
 #### Method 1. Start from the command line (CPU only)
 **Start command:**
......
English | [简体中文](README_cn.md)
## Introduction
Many users hope to package the PaddleOCR service into a Docker image so that it can be quickly released and used in Docker or Kubernetes environments.
This page provides standardized code to achieve that goal: through the following steps you can quickly publish the PaddleOCR project as a callable RESTful API service. (At present, only deployment based on the HubServing mode is implemented; deployment based on the PaddleServing mode is planned for the future.)
## 1. Prerequisites
You need to install the following basic components first:
a. Docker
b. Graphics driver and CUDA 10.0+ (GPU)
c. NVIDIA Container Toolkit (GPU; Docker 19.03+ can skip this)
d. cuDNN 7.6+ (GPU)
## 2. Build Image
a. Download the PaddleOCR source code
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Go to the Dockerfile directory (note: the CPU and GPU versions use different Dockerfiles; the following takes CPU as an example, for the GPU version replace the keyword cpu with gpu)
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start the container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (based on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03+)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check the service status (if you see the following output, the service has started successfully: Successfully installed ocr_system && Running on http://0.0.0.0:8866/)
```
docker logs -f paddle_ocr
```
## 4. Test
a. Compute the Base64 encoding of the image to be recognized (for a quick test you can use a free online tool such as https://freeonlinetools24.com/base64-image/)
b. Post a service request (a sample request is provided in sample_request.txt); a Python sketch covering steps a and b follows the curl example below
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
```
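Alternatively, a minimal Python sketch that covers steps a and b together; it assumes the service is running locally on port 8866, that a test image test.jpg exists, and that the requests package is installed:
```
import base64
import json

import requests

# Step a: Base64-encode the image (no "data:image/jpg;base64," prefix).
with open("test.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

# Step b: post the request to the ocr_system service.
resp = requests.post(
    "http://localhost:8866/predict/ocr_system",
    headers={"Content-Type": "application/json"},
    data=json.dumps({"images": [img_b64]}),
)
print(resp.json())
```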
c. Get the response (if the call succeeds, a result like the following is returned)
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
@@ -121,24 +121,22 @@ def RandomCropData(data, size):
     all_care_polys = [
         text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
     ]
-    # compute the crop area
     crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys,
                                                min_crop_side_ratio, max_tries)
-    # crop the image and pad it to keep the aspect ratio
-    scale_w = size[0] / crop_w
-    scale_h = size[1] / crop_h
+    dh, dw = size
+    scale_w = dw / crop_w
+    scale_h = dh / crop_h
     scale = min(scale_w, scale_h)
     h = int(crop_h * scale)
     w = int(crop_w * scale)
     if keep_ratio:
-        padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype)
+        padimg = np.zeros((dh, dw, im.shape[2]), im.dtype)
         padimg[:h, :w] = cv2.resize(
             im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
         img = padimg
     else:
         img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
-                         tuple(size))
-    # crop the text boxes
+                         (dw, dh))
     text_polys_crop = []
     ignore_tags_crop = []
     texts_crop = []
......
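The change above unpacks size as (dh, dw), making the height/width order explicit. As a standalone illustration, a minimal sketch of the keep-ratio resize-and-pad step, assuming size is (target_height, target_width) and crop is an HWC image:
```
import cv2
import numpy as np

def resize_keep_ratio(crop, size):
    dh, dw = size  # (target_height, target_width), as in dh, dw = size above
    scale = min(dw / crop.shape[1], dh / crop.shape[0])
    h, w = int(crop.shape[0] * scale), int(crop.shape[1] * scale)
    # Resize to fit inside the target, then zero-pad the remainder.
    padimg = np.zeros((dh, dw, crop.shape[2]), crop.dtype)
    padimg[:h, :w] = cv2.resize(crop, (w, h))
    return padimg
```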
@@ -136,7 +136,7 @@ class RecModel(object):
         else:
             labels = None
             loader = None
-        if self.char_type == "ch" and self.infer_img:
+        if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
             image_shape[-1] = -1
             if self.tps != None:
                 logger.info(
@@ -172,16 +172,13 @@ class RecModel(object):
                     self.max_text_length
                 ],
                 dtype="float32")
-            feed_list = [
-                image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
-                gsrm_slf_attn_bias2
-            ]
             labels = {
                 'encoder_word_pos': encoder_word_pos,
                 'gsrm_word_pos': gsrm_word_pos,
                 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
                 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
             }
             return image, labels, loader

     def __call__(self, mode):
@@ -218,8 +215,13 @@ class RecModel(object):
             if self.loss_type == "ctc":
                 predict = fluid.layers.softmax(predict)
             if self.loss_type == "srn":
-                raise Exception(
-                    "Warning! SRN does not support export model currently")
+                return [
+                    image, labels, {
+                        'decoded_out': decoded_out,
+                        'predicts': predict
+                    }
+                ]
             return [image, {'decoded_out': decoded_out, 'predicts': predict}]
         else:
             predict = predicts['predict']
......
@@ -40,6 +40,7 @@ class TextRecognizer(object):
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
+        self.text_len = args.max_text_length
         self.use_zero_copy_run = args.use_zero_copy_run
         char_ops_params = {
             "character_type": args.rec_char_type,
@@ -47,12 +48,15 @@ class TextRecognizer(object):
             "use_space_char": args.use_space_char,
             "max_text_length": args.max_text_length
         }
-        if self.rec_algorithm != "RARE":
+        if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
             char_ops_params['loss_type'] = 'ctc'
             self.loss_type = 'ctc'
-        else:
+        elif self.rec_algorithm == "RARE":
             char_ops_params['loss_type'] = 'attention'
             self.loss_type = 'attention'
+        elif self.rec_algorithm == "SRN":
+            char_ops_params['loss_type'] = 'srn'
+            self.loss_type = 'srn'
         self.char_ops = CharacterOps(char_ops_params)

     def resize_norm_img(self, img, max_wh_ratio):
@@ -75,6 +79,83 @@ class TextRecognizer(object):
         padding_im[:, :, 0:resized_w] = resized_image
         return padding_im

+    def resize_norm_img_srn(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+
+        img_black = np.zeros((imgH, imgW))
+        im_hei = img.shape[0]
+        im_wid = img.shape[1]
+
+        # Resize to the nearest width bucket (1x, 2x or 3x the height),
+        # then left-align on a black canvas of the full target width.
+        if im_wid <= im_hei * 1:
+            img_new = cv2.resize(img, (imgH * 1, imgH))
+        elif im_wid <= im_hei * 2:
+            img_new = cv2.resize(img, (imgH * 2, imgH))
+        elif im_wid <= im_hei * 3:
+            img_new = cv2.resize(img, (imgH * 3, imgH))
+        else:
+            img_new = cv2.resize(img, (imgW, imgH))
+
+        img_np = np.asarray(img_new)
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+        img_black[:, 0:img_np.shape[1]] = img_np
+        img_black = img_black[:, :, np.newaxis]
+
+        row, col, c = img_black.shape
+        c = 1
+
+        return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+    def srn_other_inputs(self, image_shape, num_heads, max_text_length,
+                         char_num):
+
+        imgC, imgH, imgW = image_shape
+        feature_dim = int((imgH / 8) * (imgW / 8))
+
+        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+            (feature_dim, 1)).astype('int64')
+        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+            (max_text_length, 1)).astype('int64')
+
+        # Triangular masks: bias1 blocks attention to later positions,
+        # bias2 blocks attention to earlier ones (-1e9 before softmax).
+        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias1 = np.tile(
+            gsrm_slf_attn_bias1,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias2 = np.tile(
+            gsrm_slf_attn_bias2,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+        encoder_word_pos = encoder_word_pos[np.newaxis, :]
+        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+        return [
+            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2
+        ]
+
+    def process_image_srn(self,
+                          img,
+                          image_shape,
+                          num_heads,
+                          max_text_length,
+                          char_ops=None):
+        norm_img = self.resize_norm_img_srn(img, image_shape)
+        norm_img = norm_img[np.newaxis, :]
+        char_num = char_ops.get_char_num()
+
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+            self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
+
+        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+
+        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                gsrm_slf_attn_bias2)
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -84,7 +165,7 @@ class TextRecognizer(object):
         # Sorting can speed up the recognition process
         indices = np.argsort(np.array(width_list))
-        # rec_res = []
+        #rec_res = []
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         predict_time = 0
@@ -98,20 +179,62 @@ class TextRecognizer(object):
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
-                norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                max_wh_ratio)
-                norm_img = norm_img[np.newaxis, :]
-                norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
+                if self.loss_type != "srn":
+                    norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                else:
+                    norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                      self.rec_image_shape, 8,
+                                                      25, self.char_ops)
+                    encoder_word_pos_list = []
+                    gsrm_word_pos_list = []
+                    gsrm_slf_attn_bias1_list = []
+                    gsrm_slf_attn_bias2_list = []
+                    encoder_word_pos_list.append(norm_img[1])
+                    gsrm_word_pos_list.append(norm_img[2])
+                    gsrm_slf_attn_bias1_list.append(norm_img[3])
+                    gsrm_slf_attn_bias2_list.append(norm_img[4])
+                    norm_img_batch.append(norm_img[0])
+            norm_img_batch = np.concatenate(norm_img_batch, axis=0)
             norm_img_batch = norm_img_batch.copy()
-            starttime = time.time()
-            if self.use_zero_copy_run:
-                self.input_tensor.copy_from_cpu(norm_img_batch)
-                self.predictor.zero_copy_run()
-            else:
-                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
-                self.predictor.run([norm_img_batch])
+
+            if self.loss_type == "srn":
+                starttime = time.time()
+                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = np.concatenate(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = np.concatenate(
+                    gsrm_slf_attn_bias2_list)
+                starttime = time.time()
+
+                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
+                encoder_word_pos_list = fluid.core.PaddleTensor(
+                    encoder_word_pos_list)
+                gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
+                    gsrm_slf_attn_bias2_list)
+
+                inputs = [
+                    norm_img_batch, encoder_word_pos_list,
+                    gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
+                    gsrm_word_pos_list
+                ]
+                self.predictor.run(inputs)
+            else:
+                starttime = time.time()
+                if self.use_zero_copy_run:
+                    self.input_tensor.copy_from_cpu(norm_img_batch)
+                    self.predictor.zero_copy_run()
+                else:
+                    norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
+                    self.predictor.run([norm_img_batch])
             if self.loss_type == "ctc":
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
@@ -136,6 +259,26 @@ class TextRecognizer(object):
                     score = np.mean(probs[valid_ind, ind[valid_ind]])
                     # rec_res.append([preds_text, score])
                     rec_res[indices[beg_img_no + rno]] = [preds_text, score]
+            elif self.loss_type == 'srn':
+                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
+                probs = self.output_tensors[1].copy_to_cpu()
+                char_num = self.char_ops.get_char_num()
+                preds = rec_idx_batch.reshape(-1)
+                elapse = time.time() - starttime
+                predict_time += elapse
+                total_preds = preds.copy()
+                # Predictions come back as one flat array; split it into
+                # fixed-length (text_len) chunks, one per input image.
+                for ino in range(int(len(rec_idx_batch) / self.text_len)):
+                    preds = total_preds[ino * self.text_len:(ino + 1) *
+                                        self.text_len]
+                    ind = np.argmax(probs, axis=1)
+                    valid_ind = np.where(preds != int(char_num - 1))[0]
+                    if len(valid_ind) == 0:
+                        continue
+                    score = np.mean(probs[valid_ind, ind[valid_ind]])
+                    preds = preds[:valid_ind[-1] + 1]
+                    preds_text = self.char_ops.decode(preds)
+                    rec_res[indices[beg_img_no + ino]] = [preds_text, score]
             else:
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                 predict_batch = self.output_tensors[1].copy_to_cpu()
@@ -170,6 +313,7 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
     try:
         rec_res, predict_time = text_recognizer(img_list)
     except Exception as e:
......
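For intuition about the two attention-bias masks built in srn_other_inputs above, here is a minimal standalone sketch, assuming max_text_length = 4 and a single head; entries of -1e9 are effectively removed from the attention softmax:
```
import numpy as np

max_text_length = 4
bias = np.ones((1, max_text_length, max_text_length))

# np.triu(..., 1) keeps strictly-upper entries: position i is blocked
# from attending to later positions j > i (forward decoding mask).
fwd_mask = np.triu(bias, 1) * -1e9

# np.tril(..., -1) keeps strictly-lower entries: position i is blocked
# from attending to earlier positions j < i (backward decoding mask).
bwd_mask = np.tril(bias, -1) * -1e9
```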
@@ -125,7 +125,8 @@ def create_predictor(args, mode):
     predictor = create_paddle_predictor(config)

     input_names = predictor.get_input_names()
-    input_tensor = predictor.get_input_tensor(input_names[0])
+    for name in input_names:
+        input_tensor = predictor.get_input_tensor(name)
     output_names = predictor.get_output_names()
     output_tensors = []
     for output_name in output_names:
......
@@ -145,7 +145,7 @@ def main():
         preds = preds.reshape(-1)
         probs = np.array(predict[1])
         ind = np.argmax(probs, axis=1)
-        valid_ind = np.where(preds != int(char_num-1))[0]
+        valid_ind = np.where(preds != int(char_num - 1))[0]
         if len(valid_ind) == 0:
             continue
         score = np.mean(probs[valid_ind, ind[valid_ind]])
......
@@ -209,10 +209,19 @@ def build_export(config, main_prog, startup_prog):
     with fluid.unique_name.guard():
         func_infor = config['Architecture']['function']
         model = create_module(func_infor)(params=config)
-        image, outputs = model(mode='export')
+        algorithm = config['Global']['algorithm']
+        if algorithm == "SRN":
+            image, others, outputs = model(mode='export')
+        else:
+            image, outputs = model(mode='export')
         fetches_var_name = sorted([name for name in outputs.keys()])
         fetches_var = [outputs[name] for name in fetches_var_name]
-        feeded_var_names = [image.name]
+        if algorithm == "SRN":
+            others_var_names = sorted([name for name in others.keys()])
+            feeded_var_names = [image.name] + others_var_names
+        else:
+            feeded_var_names = [image.name]
         target_vars = fetches_var
         return feeded_var_names, target_vars, fetches_var_name
......