未验证 提交 fc437d49 编写于 作者: W whs 提交者: GitHub

Change layers.data to fluid.data in ocr model (#3594)

1. Change layers.data to fluid.data
2. Add check_version in ocr model
上级 9e89cba0
>注意:在paddle1.5版本上训练attention model有收敛问题,建议您暂时使用paddle1.4版本,后续我们会修复该问题。 >注意:在paddle1.5版本上训练attention model有收敛问题,建议您暂时使用paddle1.4版本和models1.5分支,后续我们会修复该问题。
## 代码结构 ## 代码结构
``` ```
......
...@@ -165,11 +165,11 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj, ...@@ -165,11 +165,11 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
def attention_train_net(args, data_shape, num_classes): def attention_train_net(args, data_shape, num_classes):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.data(name='pixel', shape=[None] + data_shape, dtype='float32')
label_in = fluid.layers.data( label_in = fluid.data(
name='label_in', shape=[1], dtype='int32', lod_level=1) name='label_in', shape=[None, 1], dtype='int32', lod_level=1)
label_out = fluid.layers.data( label_out = fluid.data(
name='label_out', shape=[1], dtype='int32', lod_level=1) name='label_out', shape=[None, 1], dtype='int32', lod_level=1)
gru_backward, encoded_vector, encoded_proj = encoder_net(images) gru_backward, encoded_vector, encoded_proj = encoder_net(images)
...@@ -264,10 +264,10 @@ def attention_infer(images, num_classes, use_cudnn=True): ...@@ -264,10 +264,10 @@ def attention_infer(images, num_classes, use_cudnn=True):
ids_array = fluid.layers.create_array('int64') ids_array = fluid.layers.create_array('int64')
scores_array = fluid.layers.create_array('float32') scores_array = fluid.layers.create_array('float32')
init_ids = fluid.layers.data( init_ids = fluid.data(
name="init_ids", shape=[1], dtype="int64", lod_level=2) name="init_ids", shape=[None, 1], dtype="int64", lod_level=2)
init_scores = fluid.layers.data( init_scores = fluid.data(
name="init_scores", shape=[1], dtype="float32", lod_level=2) name="init_scores", shape=[None, 1], dtype="float32", lod_level=2)
fluid.layers.array_write(init_ids, array=ids_array, i=counter) fluid.layers.array_write(init_ids, array=ids_array, i=counter)
fluid.layers.array_write(init_scores, array=scores_array, i=counter) fluid.layers.array_write(init_scores, array=scores_array, i=counter)
...@@ -349,11 +349,11 @@ def attention_infer(images, num_classes, use_cudnn=True): ...@@ -349,11 +349,11 @@ def attention_infer(images, num_classes, use_cudnn=True):
def attention_eval(data_shape, num_classes, use_cudnn=True): def attention_eval(data_shape, num_classes, use_cudnn=True):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.data(name='pixel', shape=[None] + data_shape, dtype='float32')
label_in = fluid.layers.data( label_in = fluid.data(
name='label_in', shape=[1], dtype='int32', lod_level=1) name='label_in', shape=[None, 1], dtype='int32', lod_level=1)
label_out = fluid.layers.data( label_out = fluid.data(
name='label_out', shape=[1], dtype='int32', lod_level=1) name='label_out', shape=[None, 1], dtype='int32', lod_level=1)
label_out = fluid.layers.cast(x=label_out, dtype='int64') label_out = fluid.layers.cast(x=label_out, dtype='int64')
label_in = fluid.layers.cast(x=label_in, dtype='int64') label_in = fluid.layers.cast(x=label_in, dtype='int64')
......
...@@ -190,9 +190,9 @@ def ctc_train_net(args, data_shape, num_classes): ...@@ -190,9 +190,9 @@ def ctc_train_net(args, data_shape, num_classes):
learning_rate_decay = None learning_rate_decay = None
regularizer = fluid.regularizer.L2Decay(L2_RATE) regularizer = fluid.regularizer.L2Decay(L2_RATE)
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.data(name='pixel', shape=[None] + data_shape, dtype='float32')
label = fluid.layers.data( label = fluid.data(
name='label', shape=[1], dtype='int32', lod_level=1) name='label', shape=[None, 1], dtype='int32', lod_level=1)
fc_out = encoder_net( fc_out = encoder_net(
images, images,
num_classes, num_classes,
......
...@@ -32,7 +32,7 @@ except NameError: ...@@ -32,7 +32,7 @@ except NameError:
SOS = 0 SOS = 0
EOS = 1 EOS = 1
NUM_CLASSES = 95 NUM_CLASSES = 95
DATA_SHAPE = [1, 48, 512] DATA_SHAPE = [1, 48, None]
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5" DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz" DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data
from utility import check_gpu from utility import check_gpu, check_version
from attention_model import attention_eval from attention_model import attention_eval
from crnn_ctc_model import ctc_eval from crnn_ctc_model import ctc_eval
import data_reader import data_reader
...@@ -85,6 +85,7 @@ def main(): ...@@ -85,6 +85,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
check_gpu(args.use_gpu) check_gpu(args.use_gpu)
check_version()
evaluate(args) evaluate(args)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import print_function from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer, get_ctc_feeder_for_infer from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer, get_ctc_feeder_for_infer
from utility import check_gpu from utility import check_gpu, check_version
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
from crnn_ctc_model import ctc_infer from crnn_ctc_model import ctc_infer
from attention_model import attention_infer from attention_model import attention_infer
...@@ -153,6 +153,7 @@ def main(): ...@@ -153,6 +153,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
check_gpu(args.use_gpu) check_gpu(args.use_gpu)
check_version()
if args.profile: if args.profile:
if args.use_gpu: if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof: with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
......
...@@ -17,7 +17,7 @@ from __future__ import division ...@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data
from utility import check_gpu from utility import check_gpu, check_version
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
from crnn_ctc_model import ctc_train_net from crnn_ctc_model import ctc_train_net
from attention_model import attention_train_net from attention_model import attention_train_net
...@@ -230,6 +230,7 @@ def main(): ...@@ -230,6 +230,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
check_gpu(args.use_gpu) check_gpu(args.use_gpu)
check_version()
if args.profile: if args.profile:
if args.use_gpu: if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof: with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
......
...@@ -159,3 +159,18 @@ def check_gpu(use_gpu): ...@@ -159,3 +159,18 @@ def check_gpu(use_gpu):
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
pass pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册