Commit 6832ca02 authored by tink2123

update config

Parent 09d8cb6d
@@ -17,10 +17,11 @@ Global:
   average_window: 0.15
   max_average_window: 15625
   min_average_window: 10000
-  reader_yml: ./configs/rec/rec_srn_reader.yml
+  reader_yml: ./configs/rec/rec_benchmark_reader.yml
   pretrain_weights:
   checkpoints:
   save_inference_dir:
+  infer_img:
 
 Architecture:
   function: ppocr.modeling.architectures.rec_model,RecModel
...
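Note: a minimal sketch (not part of this commit) of how the updated Global keys could be consumed, assuming PyYAML is available; the config path below is a placeholder, not a file named in the diff.

import yaml  # assumption: PyYAML is installed

with open("configs/rec/my_srn_config.yml") as f:  # hypothetical path
    cfg = yaml.safe_load(f)

reader_yml = cfg["Global"]["reader_yml"]    # -> ./configs/rec/rec_benchmark_reader.yml
infer_img = cfg["Global"].get("infer_img")  # newly added key, empty by default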
@@ -118,15 +118,14 @@ class LMDBReader(object):
                 image_file_list = get_image_file_list(self.infer_img)
                 for single_img in image_file_list:
                     img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
+                    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                     if self.loss_type == 'srn':
                         norm_img = process_image_srn(
                             img=img,
                             image_shape=self.image_shape,
                             num_heads=self.num_heads,
-                            max_text_length=self.max_text_length
-                        )
+                            max_text_length=self.max_text_length)
                     else:
                         norm_img = process_image(
                             img=img,
@@ -135,20 +134,20 @@ class LMDBReader(object):
                             tps=self.use_tps,
                             infer_mode=True)
                     yield norm_img
-            elif self.mode == 'test':
-                image_file_list = get_image_file_list(self.infer_img)
-                for single_img in image_file_list:
-                    img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
-                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-                    norm_img = process_image(
-                        img=img,
-                        image_shape=self.image_shape,
-                        char_ops=self.char_ops,
-                        tps=self.use_tps,
-                        infer_mode=True
-                    )
-                    yield norm_img
+            #elif self.mode == 'eval':
+            #    image_file_list = get_image_file_list(self.infer_img)
+            #    for single_img in image_file_list:
+            #        img = cv2.imread(single_img)
+            #        if img.shape[-1]==1 or len(list(img.shape))==2:
+            #            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+            #        norm_img = process_image(
+            #            img=img,
+            #            image_shape=self.image_shape,
+            #            char_ops=self.char_ops,
+            #            tps=self.use_tps,
+            #            infer_mode=True
+            #        )
+            #        yield norm_img
             else:
                 lmdb_sets = self.load_hierarchical_lmdb_dataset()
                 if process_id == 0:
@@ -169,14 +168,15 @@ class LMDBReader(object):
                         img, label = sample_info
                         outs = []
                         if self.loss_type == "srn":
-                            outs = process_image_srn(img, self.image_shape, self.num_heads,
-                                                     self.max_text_length, label,
-                                                     self.char_ops, self.loss_type)
+                            outs = process_image_srn(
+                                img, self.image_shape, self.num_heads,
+                                self.max_text_length, label, self.char_ops,
+                                self.loss_type)
                         else:
-                            outs = process_image(img, self.image_shape, label,
-                                                 self.char_ops, self.loss_type,
-                                                 self.max_text_length)
+                            outs = process_image(
+                                img, self.image_shape, label, self.char_ops,
+                                self.loss_type, self.max_text_length)
                         if outs is None:
                             continue
                         yield outs
@@ -184,6 +184,7 @@ class LMDBReader(object):
                     if finish_read_num == len(lmdb_sets):
                         break
                 self.close_lmdb_dataset(lmdb_sets)
+
         def batch_iter_reader():
             batch_outs = []
             for outs in sample_iter_reader():
...
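Note: a standalone sketch of the grayscale guard that the hunks above reformat, assuming OpenCV and NumPy are installed; the input image is synthetic, not from the repo.

import cv2
import numpy as np

img = np.zeros((32, 100), dtype=np.uint8)  # synthetic single-channel image
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)  # expand to 3-channel BGR
print(img.shape)  # (32, 100, 3)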
@@ -79,17 +79,45 @@ class RecModel(object):
                 feed_list = [image, label_in, label_out]
                 labels = {'label_in': label_in, 'label_out': label_out}
             elif self.loss_type == "srn":
-                encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
-                gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
-                gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64')
+                encoder_word_pos = fluid.data(
+                    name="encoder_word_pos",
+                    shape=[
+                        -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
+                        1
+                    ],
+                    dtype="int64")
+                gsrm_word_pos = fluid.data(
+                    name="gsrm_word_pos",
+                    shape=[-1, self.max_text_length, 1],
+                    dtype="int64")
+                gsrm_slf_attn_bias1 = fluid.data(
+                    name="gsrm_slf_attn_bias1",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                gsrm_slf_attn_bias2 = fluid.data(
+                    name="gsrm_slf_attn_bias2",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                lbl_weight = fluid.layers.data(
+                    name="lbl_weight", shape=[-1, 1], dtype='int64')
                 label = fluid.data(
                     name='label', shape=[-1, 1], dtype='int32', lod_level=1)
-                feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight]
-                labels = {'label': label, 'encoder_word_pos': encoder_word_pos,
-                          'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
-                          'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight}
+                feed_list = [
+                    image, label, encoder_word_pos, gsrm_word_pos,
+                    gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight
+                ]
+                labels = {
+                    'label': label,
+                    'encoder_word_pos': encoder_word_pos,
+                    'gsrm_word_pos': gsrm_word_pos,
+                    'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
+                    'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,
+                    'lbl_weight': lbl_weight
+                }
             else:
                 label = fluid.data(
                     name='label', shape=[None, 1], dtype='int32', lod_level=1)
@@ -114,13 +142,39 @@ class RecModel(object):
             image_shape = deepcopy(self.image_shape)
             image = fluid.data(name='image', shape=image_shape, dtype='float32')
             if self.loss_type == "srn":
-                encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
-                gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
-                gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
-                labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos,
-                          'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2}
+                encoder_word_pos = fluid.data(
+                    name="encoder_word_pos",
+                    shape=[
+                        -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
+                        1
+                    ],
+                    dtype="int64")
+                gsrm_word_pos = fluid.data(
+                    name="gsrm_word_pos",
+                    shape=[-1, self.max_text_length, 1],
+                    dtype="int64")
+                gsrm_slf_attn_bias1 = fluid.data(
+                    name="gsrm_slf_attn_bias1",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                gsrm_slf_attn_bias2 = fluid.data(
+                    name="gsrm_slf_attn_bias2",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                feed_list = [
+                    image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                    gsrm_slf_attn_bias2
+                ]
+                labels = {
+                    'encoder_word_pos': encoder_word_pos,
+                    'gsrm_word_pos': gsrm_word_pos,
+                    'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
+                    'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
+                }
         return image, labels, loader
 
     def __call__(self, mode):
@@ -140,8 +194,13 @@ class RecModel(object):
             label = labels['label']
             if self.loss_type == 'srn':
                 total_loss, img_loss, word_loss = self.loss(predicts, labels)
-                outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss,
-                           'decoded_out':decoded_out, 'label':label}
+                outputs = {
+                    'total_loss': total_loss,
+                    'img_loss': img_loss,
+                    'word_loss': word_loss,
+                    'decoded_out': decoded_out,
+                    'label': label
+                }
             else:
                 outputs = {'total_loss':loss, 'decoded_out':\
                     decoded_out, 'label':label}
...
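Note: to make the SRN feed shapes above concrete, here is a small NumPy sketch that builds dummy versions of the four auxiliary inputs; the image size, head count, and text length are example values, not taken from this commit.

import numpy as np

image_shape = [1, 64, 256]   # C, H, W (example values)
max_text_length = 25         # example value
num_heads = 8                # example value

feature_len = int((image_shape[-2] / 8) * (image_shape[-1] / 8))  # 8 * 32 = 256
encoder_word_pos = np.arange(feature_len).reshape((feature_len, 1)).astype('int64')
gsrm_word_pos = np.arange(max_text_length).reshape((max_text_length, 1)).astype('int64')
gsrm_slf_attn_bias1 = np.zeros(
    (num_heads, max_text_length, max_text_length), dtype='float32')
gsrm_slf_attn_bias2 = np.zeros(
    (num_heads, max_text_length, max_text_length), dtype='float32')
# shapes match the fluid.data declarations above, minus the batch dimension
print(encoder_word_pos.shape, gsrm_word_pos.shape, gsrm_slf_attn_bias1.shape)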
@@ -4,8 +4,9 @@ import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
-from .desc import *
-from .config import ModelHyperParams,TrainTaskConfig
+# Set seed for CE
+dropout_seed = None
 
 
 def wrap_layer_with_block(layer, block_idx):
     """
@@ -269,7 +270,8 @@ pre_process_layer = partial(pre_post_process_layer, None)
 post_process_layer = pre_post_process_layer
 
 
-def prepare_encoder(src_word,#[b,t,c]
+def prepare_encoder(
+        src_word,  #[b,t,c]
         src_pos,
         src_vocab_size,
         src_emb_dim,
@@ -284,8 +286,8 @@ def prepare_encoder(src_word,#[b,t,c]
     This module is used at the bottom of the encoder stacks.
     """
-    src_word_emb =src_word#layers.concat(res,axis=1)
-    src_word_emb=layers.cast(src_word_emb,'float32')
+    src_word_emb = src_word  #layers.concat(res,axis=1)
+    src_word_emb = layers.cast(src_word_emb, 'float32')
     # print("src_word_emb",src_word_emb)
     src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
@@ -323,7 +325,7 @@ def prepare_decoder(src_word,
             name=word_emb_param_name,
             initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
     # print("target_word_emb",src_word_emb)
-    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
+    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -335,6 +337,7 @@ def prepare_decoder(src_word,
         enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
         is_test=False) if dropout_rate else enc_input
 
+
 # prepare_encoder = partial(
 # prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
 # prepare_decoder = partial(
@@ -595,21 +598,9 @@ def transformer(src_vocab_size,
     weights = all_inputs[-1]
 
     enc_output = wrap_encoder(
-        src_vocab_size,
-        ModelHyperParams.src_seq_len,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        prepostprocess_dropout,
-        attention_dropout,
-        relu_dropout,
-        preprocess_cmd,
-        postprocess_cmd,
-        weight_sharing,
-        enc_inputs)
+        src_vocab_size, 64, n_layer, n_head, d_key, d_value, d_model,
+        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
+        preprocess_cmd, postprocess_cmd, weight_sharing, enc_inputs)
 
     predict = wrap_decoder(
         trg_vocab_size,
@@ -676,8 +667,8 @@ def wrap_encoder_forFeature(src_vocab_size,
         conv_features, src_pos, src_slf_attn_bias = make_all_inputs(
             encoder_data_input_fields)
     else:
-        conv_features, src_pos, src_slf_attn_bias = enc_inputs#
-        b,t,c = conv_features.shape
+        conv_features, src_pos, src_slf_attn_bias = enc_inputs #
+        b, t, c = conv_features.shape
     #"""
     # insert cnn
     #"""
@@ -718,7 +709,7 @@ def wrap_encoder_forFeature(src_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0])
+        word_emb_param_name="src_word_emb_table")
 
     enc_output = encoder(
         enc_input,
@@ -736,6 +727,7 @@ def wrap_encoder_forFeature(src_vocab_size,
         postprocess_cmd, )
     return enc_output
 
+
 def wrap_encoder(src_vocab_size,
                  max_length,
                  n_layer,
@@ -762,7 +754,7 @@ def wrap_encoder(src_vocab_size,
         src_word, src_pos, src_slf_attn_bias = make_all_inputs(
             encoder_data_input_fields)
     else:
-        src_word, src_pos, src_slf_attn_bias = enc_inputs#
+        src_word, src_pos, src_slf_attn_bias = enc_inputs #
     #"""
     # insert cnn
     #"""
@@ -802,7 +794,7 @@ def wrap_encoder(src_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0])
+        word_emb_param_name="src_word_emb_table")
 
     enc_output = encoder(
         enc_input,
@@ -858,8 +850,8 @@ def wrap_decoder(trg_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0]
-        if weight_sharing else word_emb_param_names[1])
+        word_emb_param_name="src_word_emb_table"
+        if weight_sharing else "trg_word_emb_table")
     dec_output = decoder(
         dec_input,
         enc_output,
@@ -886,7 +878,7 @@ def wrap_decoder(trg_vocab_size,
         predict = layers.matmul(
             x=dec_output,
             y=fluid.default_main_program().global_block().var(
-                word_emb_param_names[0]),
+                "trg_word_emb_table"),
             transpose_y=True)
     else:
         predict = layers.fc(input=dec_output,
@@ -931,12 +923,13 @@ def fast_decode(src_vocab_size,
     enc_inputs_len = len(encoder_data_input_fields)
     dec_inputs_len = len(fast_decoder_data_input_fields)
-    enc_inputs = all_inputs[0:enc_inputs_len]#enc_inputs tensor
-    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]#dec_inputs tensor
+    enc_inputs = all_inputs[0:enc_inputs_len] #enc_inputs tensor
+    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len +
+                            dec_inputs_len] #dec_inputs tensor
 
     enc_output = wrap_encoder(
         src_vocab_size,
-        ModelHyperParams.src_seq_len,##to do !!!!!????
+        64, ##to do !!!!!????
         n_layer,
         n_head,
         d_key,
...
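Note: the edits above replace word_emb_param_names[...] with literal embedding-table names; the helper below is illustrative only (it is not part of the file) and simply mirrors the selection logic of the wrap_encoder/wrap_decoder calls.

def emb_table_name(for_decoder, weight_sharing):
    # the encoder always uses the source table; the decoder reuses it
    # only when weight sharing is enabled
    if for_decoder and not weight_sharing:
        return "trg_word_emb_table"
    return "src_word_emb_table"

assert emb_table_name(False, False) == "src_word_emb_table"
assert emb_table_name(True, True) == "src_word_emb_table"
assert emb_table_name(True, False) == "trg_word_emb_table"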
@@ -75,7 +75,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
             preds_lod = outs[0].lod()[0]
             labels, labels_lod = convert_rec_label_to_lod(label_list)
             acc, acc_num, sample_num = cal_predicts_accuracy(
-                char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
+                char_ops, preds, preds_lod, labels, labels_lod,
+                is_remove_duplicate)
         else:
             encoder_word_pos_list = []
             gsrm_word_pos_list = []
@@ -89,10 +90,14 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
             img_list = np.concatenate(img_list, axis=0)
             label_list = np.concatenate(label_list, axis=0)
-            encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64)
-            gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64)
-            gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
-            gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
+            encoder_word_pos_list = np.concatenate(
+                encoder_word_pos_list, axis=0).astype(np.int64)
+            gsrm_word_pos_list = np.concatenate(
+                gsrm_word_pos_list, axis=0).astype(np.int64)
+            gsrm_slf_attn_bias1_list = np.concatenate(
+                gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
+            gsrm_slf_attn_bias2_list = np.concatenate(
+                gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
             labels = label_list
@@ -108,7 +113,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
         total_acc_num += acc_num
         total_sample_num += sample_num
-        logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
+        #logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
         total_batch_num += 1
     avg_acc = total_acc_num * 1.0 / total_sample_num
     metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
...
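Note: a self-contained NumPy sketch of the batch assembly this file reformats, using toy per-sample arrays (the real lists come from the SRN reader).

import numpy as np

encoder_word_pos_list = [np.zeros((1, 256, 1)) for _ in range(4)]        # toy data
gsrm_slf_attn_bias1_list = [np.zeros((1, 8, 25, 25)) for _ in range(4)]  # toy data

encoder_word_pos = np.concatenate(
    encoder_word_pos_list, axis=0).astype(np.int64)
gsrm_slf_attn_bias1 = np.concatenate(
    gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
print(encoder_word_pos.shape, gsrm_slf_attn_bias1.shape)  # (4, 256, 1) (4, 8, 25, 25)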
@@ -34,6 +34,7 @@ from ppocr.utils.save_load import save_model
 import numpy as np
 from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
 
+
 class ArgsParser(ArgumentParser):
     def __init__(self):
         super(ArgsParser, self).__init__(
@@ -196,10 +197,13 @@ def build(config, main_prog, startup_prog, mode):
                 if config['Global']["loss_type"] == 'srn':
                     model_average = fluid.optimizer.ModelAverage(
                         config['Global']['average_window'],
-                        min_average_window=config['Global']['min_average_window'],
-                        max_average_window=config['Global']['max_average_window'])
+                        min_average_window=config['Global'][
+                            'min_average_window'],
+                        max_average_window=config['Global'][
+                            'max_average_window'])
 
-    return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average)
+    return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
+            model_average)
 
 
 def build_export(config, main_prog, startup_prog):
@@ -398,6 +402,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
                     save_model(train_info_dict['train_program'], save_path)
     return
 
+
 def preprocess():
     FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
@@ -409,8 +414,8 @@ def preprocess():
     check_gpu(use_gpu)
 
     alg = config['Global']['algorithm']
-    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
-    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
+    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
+    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
         config['Global']['char_ops'] = CharacterOps(config['Global'])
 
     place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
...
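Note: a toy sketch of the updated algorithm whitelist check in preprocess(); the config dict here is a stand-in for the loaded YAML, not the real load_config output.

config = {'Global': {'algorithm': 'SRN'}}  # stand-in for the parsed config
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
    # recognition algorithms (now including SRN) build CharacterOps from Global
    print("build CharacterOps from config['Global']")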
/workspace/PaddleOCR/train_data/
\ No newline at end of file