提交 9db963b0 编写于 作者: L LiuHao 提交者: pkpk

update code for v1.6 (#3621)

上级 f3a6dbbc
......@@ -29,7 +29,7 @@
1. PaddlePaddle 安装
本项目依赖于 PaddlePaddle Fluid 1.3.2 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 PaddlePaddle Fluid 1.6 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 代码安装
......@@ -42,7 +42,7 @@
3. 环境依赖
请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容
Python 2 的版本要求 2.7.15+,Python 3 的版本要求 3.5.1+/3.6/3.7,其它环境请参考 PaddlePaddle [安装说明](https://www.paddlepaddle.org.cn/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html) 部分的内容
### 代码结构说明
......
......@@ -72,12 +72,12 @@ class SentaProcessor(object):
Generate data for train, dev or infer
"""
if phase == "train":
return paddle.batch(self.get_train_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
return fluid.io.batch(self.get_train_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
#return self.get_train_examples(self.data_dir, epoch, self.max_seq_len)
elif phase == "dev":
return paddle.batch(self.get_dev_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
return fluid.io.batch(self.get_dev_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
elif phase == "infer":
return paddle.batch(self.get_test_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
return fluid.io.batch(self.get_test_examples(self.data_dir, epoch, self.max_seq_len), batch_size)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'infer'].")
......@@ -21,6 +21,7 @@ from nets import cnn_net
from nets import bilstm_net
from nets import gru_net
from models.model_check import check_cuda
from models.model_check import check_version
from config import PDConfig
import paddle
......@@ -39,11 +40,11 @@ def create_model(args,
"""
data = fluid.layers.data(
name="src_ids", shape=[-1, args.max_seq_len, 1], dtype='int64')
name="src_ids", shape=[-1, args.max_seq_len], dtype='int64')
label = fluid.layers.data(
name="label", shape=[-1, 1], dtype="int64")
seq_len = fluid.layers.data(
name="seq_len", shape=[-1, 1], dtype="int64")
name="seq_len", shape=[-1], dtype="int64")
data_reader = fluid.io.PyReader(feed_list=[data, label, seq_len],
capacity=4, iterable=False)
......
......@@ -48,16 +48,21 @@ def ernie_pyreader(args, pyreader_name):
labels = fluid.layers.data(
name="labels", shape=[-1, 1], dtype="int64")
seq_lens = fluid.layers.data(
name="seq_lens", shape=[-1, 1], dtype="int64")
pyreader = fluid.io.PyReader(feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens],
capacity=4, iterable=False)
name="seq_lens", shape=[-1], dtype="int64")
pyreader = fluid.io.DataLoader.from_generator(
feed_list=[src_ids, sent_ids, pos_ids, input_mask, labels, seq_lens],
capacity=50,
iterable=False,
use_double_buffer=True)
ernie_inputs = {
"src_ids": src_ids,
"sent_ids": sent_ids,
"pos_ids": pos_ids,
"input_mask": input_mask,
"seq_lens": seq_lens}
return pyreader, ernie_inputs, labels
def create_model(args,
......@@ -299,15 +304,15 @@ def main(args):
if args.do_train:
train_exe = exe
train_pyreader.decorate_batch_generator(train_data_generator)
train_pyreader.set_batch_generator(train_data_generator)
else:
train_exe = None
if args.do_val:
test_exe = exe
test_pyreader.decorate_batch_generator(test_data_generator)
test_pyreader.set_batch_generator(test_data_generator)
if args.do_infer:
test_exe = exe
infer_pyreader.decorate_batch_generator(infer_data_generator)
infer_pyreader.set_batch_generator(infer_data_generator)
if args.do_train:
train_pyreader.start()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册