未验证 提交 3f8aa0f7 编写于 作者: H Hongyu Liu 提交者: GitHub

add lstm api support (#2545)

* add lstm api support; test=develop

* change name from basic_api to basic_lstm; test=develop

* fix format problem; test=develop
上级 c25124db
......@@ -40,7 +40,7 @@ def parse_args():
"--rnn_model",
type=str,
default="static",
help="model_type [static|padding|cudnn]")
help="model_type [static|padding|cudnn|basic_lstm]")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument('--para_init', action='store_true')
......
......@@ -70,7 +70,7 @@ class RNNConfig(object):
else:
raise ValueError('Unsupported model_type.')
if args.rnn_model not in ('static', 'padding', 'cudnn'):
if args.rnn_model not in ('static', 'padding', 'cudnn', 'basic_lstm'):
raise ValueError('Unsupported rnn_model.')
if args.batch_size > 0:
......
......@@ -233,7 +233,7 @@ def main():
program=inference_program,
feed=input_data_feed,
fetch_list=[loss.name, last_hidden.name, last_cell.name],
use_program_cache=True)
use_program_cache=False)
cost_eval = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1])
......@@ -259,12 +259,9 @@ def main():
total_loss = 0
iters = 0
for batch_id, batch in enumerate(train_data_iter):
if batch_id == 0:
init_hidden, init_cell = generate_init_data()
else:
init_hidden = None
init_cell = None
for batch_id, batch in enumerate(train_data_iter):
input_data_feed = prepare_input(
batch,
init_hidden=init_hidden,
......@@ -276,13 +273,16 @@ def main():
batch_start_time = time.time()
fetch_outs = exe.run(train_program,
feed=input_data_feed,
fetch_list=[loss.name, "learning_rate"],
fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
use_program_cache=True)
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_outs[2])
init_cell = np.array(fetch_outs[3])
total_loss += cost_train
iters += config.num_steps
......@@ -312,8 +312,6 @@ def main():
if batch_id == 0:
batch_time = 0
batch_start_time = time.time()
data_feeds["init_hidden"] = init_hidden
data_feeds["init_cell"] = init_cell
else:
batch_time = time.time() - batch_start_time
batch_times.append(batch_time)
......@@ -321,14 +319,19 @@ def main():
new_lr = generate_new_lr(epoch_id, device_count)
data_feeds['learning_rate'] = new_lr
data_feeds["init_hidden"] = init_hidden
data_feeds["init_cell"] = init_cell
fetch_outs = exe.run(train_program,
feed=data_feeds,
fetch_list=[loss.name, "learning_rate"],
fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
use_program_cache=True)
cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_list[2])
init_cell = np.array( fetch_list[3] )
total_loss += cost_train
iters += config.num_steps
......
......@@ -20,6 +20,8 @@ import paddle.fluid.layers as layers
import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.contrib.layers import basic_lstm
def lm_model(hidden_size,
......@@ -358,6 +360,10 @@ def lm_model(hidden_size,
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2])
elif rnn_model == "basic_lstm":
print("basic api")
rnn_out, last_hidden, last_cell = basic_lstm( x_emb, init_hidden, init_cell, hidden_size, num_layers=num_layers, \
batch_first=True, dropout_prob=dropout, param_attr = ParamAttr( initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale) ), bias_attr = ParamAttr( initializer = fluid.initializer.Constant(0.0) ))
else:
print("type not support")
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册