未验证 提交 35a81579 编写于 作者: G Guo Sheng 提交者: GitHub

Update Transformer using Paddle-1.6 apis. (#3614)

* Update Transformer using Paddle-1.6 apis.

* Add check_version for Paddle-1.6 and reorganize utils including check, configuration, etc.
上级 f37e4f8e
......@@ -5,9 +5,8 @@
```text
.
├── images # README 文档中的图片
├── palm # 工具包
├── utils # 工具包
├── desc.py # 输入描述文件
├── dist_utils.py # 多进程训练工具
├── gen_data.sh # 数据生成脚本
├── inference_model.py # 保存 inference_model 的脚本
├── main.py # 主程序入口
......@@ -33,7 +32,7 @@
1. paddle安装
本项目依赖于 PaddlePaddle Fluid 1.5.0 及以上版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
本项目依赖于 PaddlePaddle 1.6及以上版本或适当的develop版本,请参考 [安装指南](http://www.paddlepaddle.org/#quick-start) 进行安装
2. 下载代码
......
......@@ -15,9 +15,9 @@
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
batch_size = None
# The placeholder for squence length in compile time.
seq_len = 256
seq_len = None
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
......@@ -27,11 +27,11 @@ d_model = 512
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# [batch_size, max_src_len_in_batch]
"src_word": [(batch_size, seq_len), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
"src_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
......@@ -39,11 +39,11 @@ input_descs = {
"src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
"trg_word": [(batch_size, seq_len), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
"trg_pos": [(batch_size, seq_len), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
......@@ -60,11 +60,11 @@ input_descs = {
"enc_output": [(batch_size, seq_len, d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
"lbl_word": [(None, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
"lbl_weight": [(None, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
......
......@@ -22,9 +22,9 @@ import numpy as np
import paddle
import paddle.fluid as fluid
#include palm for easier nlp coding
from palm.toolkit.input_field import InputField
from palm.toolkit.configure import PDConfig
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import desc
......@@ -149,7 +149,6 @@ def do_predict(args):
is_training=False, model_input=input_field, args=args)
out_ids, out_scores = predictions
out_ids.persistable = out_scores.persistable = True
# This is used here to set dropout to the test mode.
test_prog = test_prog.clone(for_test=True)
......@@ -185,8 +184,8 @@ def do_predict(args):
f = open(args.output_file, "wb")
# start predicting
## decorate the pyreader with batch_generator
input_field.reader.decorate_batch_generator(batch_generator)
input_field.reader.start()
input_field.loader.set_batch_generator(batch_generator)
input_field.loader.start()
while True:
try:
seq_ids, seq_scores = exe.run(
......@@ -231,5 +230,7 @@ if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_predict(args)
......@@ -84,12 +84,12 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
......@@ -103,6 +103,8 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
lbl_word = lbl_word.reshape(-1, 1)
lbl_weight = lbl_weight.reshape(-1, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
......@@ -122,9 +124,9 @@ def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head, place):
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1, 1)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_word = trg_word.reshape(-1, 1)
src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len)
def to_lodtensor(data, place, lod=None):
data_tensor = fluid.LoDTensor()
......
......@@ -22,13 +22,13 @@ import numpy as np
import paddle
import paddle.fluid as fluid
#include palm for easier nlp coding
from palm.toolkit.input_field import InputField
from palm.toolkit.configure import PDConfig
import utils.dist_utils as dist_utils
from utils.input_field import InputField
from utils.configure import PDConfig
from utils.check import check_gpu, check_version
# include task-specific libs
import desc
import dist_utils
import reader
from transformer import create_net, position_encoding_init
......@@ -188,8 +188,6 @@ def do_train(args):
sum_cost, avg_cost, token_num = create_net(
is_training=True, model_input=input_field, args=args)
sum_cost.persistable = avg_cost.persistable = token_num.persistable = True
# define the optimizer
with fluid.default_main_program()._lr_schedule_guard():
......@@ -206,7 +204,7 @@ def do_train(args):
# prepare training
## decorate the pyreader with batch_generator
input_field.reader.decorate_batch_generator(batch_generator)
input_field.loader.set_batch_generator(batch_generator)
## define the executor and program for training
......@@ -254,7 +252,7 @@ def do_train(args):
step_idx = 0
for pass_id in range(args.epoch):
pass_start_time = time.time()
input_field.reader.start()
input_field.loader.start()
batch_id = 0
while True:
......@@ -303,7 +301,7 @@ def do_train(args):
step_idx += 1
except fluid.core.EOFException:
input_field.reader.reset()
input_field.loader.reset()
break
time_consumed = time.time() - pass_start_time
......@@ -323,5 +321,7 @@ if __name__ == "__main__":
args = PDConfig(yaml_file="./transformer.yaml")
args.build()
args.Print()
check_gpu(args.use_cuda)
check_version()
do_train(args)
......@@ -297,20 +297,19 @@ def prepare_encoder_decoder(src_word,
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word_emb = fluid.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
param_attr=fluid.ParamAttr(name=word_emb_param_name,
initializer=fluid.initializer.Normal(
0., src_emb_dim**-0.5)))
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
src_pos_enc = fluid.embedding(src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
......@@ -506,38 +505,8 @@ def decoder(dec_input,
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[input_descs[input_field][0] for input_field in input_fields],
dtypes=[input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0
for input_field in input_fields
])
return layers.read_file(reader), reader
def transformer(src_vocab_size,
def transformer(model_input,
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
......@@ -554,96 +523,76 @@ def transformer(src_vocab_size,
weight_sharing,
label_smooth_eps,
bos_idx=0,
use_py_reader=False,
is_test=False,
model_input=None):
is_test=False):
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
if model_input:
enc_inputs = (model_input.src_word, model_input.src_pos,
model_input.src_slf_attn_bias)
dec_inputs = (model_input.trg_word, model_input.trg_pos,
model_input.trg_slf_attn_bias,
model_input.trg_src_attn_bias)
label = model_input.lbl_word
weights = model_input.lbl_weight
else:
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names,
is_test)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
label = all_inputs[-2]
weights = all_inputs[-1]
enc_output = wrap_encoder(
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs,
bos_idx=bos_idx)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs,
enc_output, )
enc_inputs = (model_input.src_word, model_input.src_pos,
model_input.src_slf_attn_bias)
dec_inputs = (model_input.trg_word, model_input.trg_pos,
model_input.trg_slf_attn_bias, model_input.trg_src_attn_bias)
label = model_input.lbl_word
weights = model_input.lbl_weight
enc_output = wrap_encoder(enc_inputs,
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_idx=bos_idx)
predict = wrap_decoder(dec_inputs,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_output=enc_output)
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if label_smooth_eps:
label = layers.label_smooth(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
# TODO: use fluid.input.one_hot after softmax_with_cross_entropy removing
# the enforcement that the last dimension of label must be 1.
label = layers.label_smooth(label=layers.one_hot(input=label,
depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
weighted_cost = layers.elementwise_mul(x=cost, y=weights, axis=0)
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None
return sum_cost, avg_cost, predict, token_num
def wrap_encoder(src_vocab_size,
def wrap_encoder(enc_inputs,
src_vocab_size,
max_length,
n_layer,
n_head,
......@@ -657,17 +606,11 @@ def wrap_encoder(src_vocab_size,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias = enc_inputs
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
......@@ -694,7 +637,8 @@ def wrap_encoder(src_vocab_size,
return enc_output
def wrap_decoder(trg_vocab_size,
def wrap_decoder(dec_inputs,
trg_vocab_size,
max_length,
n_layer,
n_head,
......@@ -708,7 +652,6 @@ def wrap_decoder(trg_vocab_size,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=None,
enc_output=None,
caches=None,
gather_idx=None,
......@@ -716,12 +659,7 @@ def wrap_decoder(trg_vocab_size,
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
make_all_inputs(decoder_data_input_fields)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
......@@ -770,66 +708,36 @@ def wrap_decoder(trg_vocab_size,
return predict
def fast_decode(src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
bos_idx,
eos_idx,
use_py_reader=False,
model_input=None):
def fast_decode(model_input, src_vocab_size, trg_vocab_size, max_in_len,
n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd, weight_sharing, beam_size,
max_out_len, bos_idx, eos_idx):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
if model_input:
enc_inputs = (model_input.src_word, model_input.src_pos,
model_input.src_slf_attn_bias)
dec_inputs = (model_input.trg_word, model_input.init_score,
model_input.init_idx, model_input.trg_src_attn_bias)
else:
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(
src_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs,
bos_idx=bos_idx)
enc_inputs = (model_input.src_word, model_input.src_pos,
model_input.src_slf_attn_bias)
dec_inputs = (model_input.trg_word, model_input.init_score,
model_input.init_idx, model_input.trg_src_attn_bias)
enc_output = wrap_encoder(enc_inputs,
src_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
bos_idx=bos_idx)
start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs
def beam_search():
......@@ -875,7 +783,7 @@ def fast_decode(src_vocab_size,
pre_ids = layers.array_read(array=ids, i=step_idx)
# Since beam_search_op dosen't enforce pre_ids' shape, we can do
# inplace reshape here which actually change the shape of pre_ids.
pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
# pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
pre_scores = layers.array_read(array=scores, i=step_idx)
# gather cell states corresponding to selected parent
pre_src_attn_bias = layers.gather(
......@@ -884,30 +792,29 @@ def fast_decode(src_vocab_size,
x=layers.fill_constant_batch_size_like(
input=pre_src_attn_bias, # cann't use lod tensor here
value=1,
shape=[-1, 1, 1],
shape=[-1, 1],
dtype=pre_ids.dtype),
y=step_idx,
axis=0)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=enc_output,
caches=caches,
gather_idx=parent_idx,
bos_idx=bos_idx)
logits = wrap_decoder((pre_ids, pre_pos, None, pre_src_attn_bias),
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_output=enc_output,
caches=caches,
gather_idx=parent_idx,
bos_idx=bos_idx)
# intra-beam topK
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
......@@ -941,51 +848,26 @@ def fast_decode(src_vocab_size,
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores, reader if use_py_reader else None
return finished_ids, finished_scores
def create_net(is_training, model_input, args):
if is_training:
sum_cost, avg_cost, _, token_num, _ = transformer(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.label_smooth_eps,
args.bos_idx,
model_input=model_input)
sum_cost, avg_cost, _, token_num = transformer(
model_input, args.src_vocab_size, args.trg_vocab_size,
args.max_length + 1, args.n_layer, args.n_head, args.d_key,
args.d_value, args.d_model, args.d_inner_hid,
args.prepostprocess_dropout, args.attention_dropout,
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.label_smooth_eps, args.bos_idx)
return sum_cost, avg_cost, token_num
else:
out_ids, out_scores, _ = fast_decode(
args.src_vocab_size,
args.trg_vocab_size,
args.max_length + 1,
args.n_layer,
args.n_head,
args.d_key,
args.d_value,
args.d_model,
args.d_inner_hid,
args.prepostprocess_dropout,
args.attention_dropout,
args.relu_dropout,
args.preprocess_cmd,
args.postprocess_cmd,
args.weight_sharing,
args.beam_size,
args.max_out_len,
args.bos_idx,
args.eos_idx,
model_input=model_input)
out_ids, out_scores = fast_decode(
model_input, args.src_vocab_size, args.trg_vocab_size,
args.max_length + 1, args.n_layer, args.n_head, args.d_key,
args.d_value, args.d_model, args.d_inner_hid,
args.prepostprocess_dropout, args.attention_dropout,
args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
args.weight_sharing, args.beam_size, args.max_out_len, args.bos_idx,
args.eos_idx)
return out_ids, out_scores
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
__all__ = ['check_gpu', 'check_version']
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
......@@ -93,7 +93,7 @@ class InputField(object):
self.feed_list_str = []
self.feed_list = []
self.reader = None
self.loader = None
if input_slots:
for input_slot in input_slots:
......@@ -135,22 +135,17 @@ class InputField(object):
for _name, _shape, _dtype, _lod_level in zip(
self.names, self.shapes, self.dtypes, self.lod_levels):
self.input_slots[_name] = fluid.layers.data(
self.input_slots[_name] = fluid.data(
name=_name, shape=_shape, dtype=_dtype, lod_level=_lod_level)
for name in self.feed_list_str:
self.feed_list.append(self.input_slots[name])
if build_pyreader:
self.reader = fluid.io.PyReader(
feed_list=self.feed_list, capacity=capacity, iterable=iterable)
def start(self, generator=None):
if generator is not None:
self.reader.decorate_batch_generator(generator)
self.reader.start()
self.loader = fluid.io.DataLoader.from_generator(
feed_list=self.feed_list,
capacity=capacity,
iterable=(not build_pyreader),
use_double_buffer=True)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册