未验证 提交 3f8aa0f7 编写于 作者: H Hongyu Liu 提交者: GitHub

add lstm api support (#2545)

* add lstm api support; test=develop

* change name from basic_api to basic_lstm; test=develop

* fix format problem; test=develop
上级 c25124db
...@@ -40,7 +40,7 @@ def parse_args(): ...@@ -40,7 +40,7 @@ def parse_args():
"--rnn_model", "--rnn_model",
type=str, type=str,
default="static", default="static",
help="model_type [static|padding|cudnn]") help="model_type [static|padding|cudnn|basic_lstm]")
parser.add_argument( parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test") "--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument('--para_init', action='store_true') parser.add_argument('--para_init', action='store_true')
......
...@@ -70,7 +70,7 @@ class RNNConfig(object): ...@@ -70,7 +70,7 @@ class RNNConfig(object):
else: else:
raise ValueError('Unsupported model_type.') raise ValueError('Unsupported model_type.')
if args.rnn_model not in ('static', 'padding', 'cudnn'): if args.rnn_model not in ('static', 'padding', 'cudnn', 'basic_lstm'):
raise ValueError('Unsupported rnn_model.') raise ValueError('Unsupported rnn_model.')
if args.batch_size > 0: if args.batch_size > 0:
......
...@@ -9,4 +9,4 @@ function run_train() { ...@@ -9,4 +9,4 @@ function run_train() {
--use_gpu True --use_gpu True
} }
run_train run_train
\ No newline at end of file
...@@ -233,7 +233,7 @@ def main(): ...@@ -233,7 +233,7 @@ def main():
program=inference_program, program=inference_program,
feed=input_data_feed, feed=input_data_feed,
fetch_list=[loss.name, last_hidden.name, last_cell.name], fetch_list=[loss.name, last_hidden.name, last_cell.name],
use_program_cache=True) use_program_cache=False)
cost_eval = np.array(fetch_outs[0]) cost_eval = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1]) init_hidden = np.array(fetch_outs[1])
...@@ -259,12 +259,9 @@ def main(): ...@@ -259,12 +259,9 @@ def main():
total_loss = 0 total_loss = 0
iters = 0 iters = 0
init_hidden, init_cell = generate_init_data()
for batch_id, batch in enumerate(train_data_iter): for batch_id, batch in enumerate(train_data_iter):
if batch_id == 0:
init_hidden, init_cell = generate_init_data()
else:
init_hidden = None
init_cell = None
input_data_feed = prepare_input( input_data_feed = prepare_input(
batch, batch,
init_hidden=init_hidden, init_hidden=init_hidden,
...@@ -276,13 +273,16 @@ def main(): ...@@ -276,13 +273,16 @@ def main():
batch_start_time = time.time() batch_start_time = time.time()
fetch_outs = exe.run(train_program, fetch_outs = exe.run(train_program,
feed=input_data_feed, feed=input_data_feed,
fetch_list=[loss.name, "learning_rate"], fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
use_program_cache=True) use_program_cache=True)
batch_time = time.time() - batch_start_time batch_time = time.time() - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
cost_train = np.array(fetch_outs[0]) cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1]) lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_outs[2])
init_cell = np.array(fetch_outs[3])
total_loss += cost_train total_loss += cost_train
iters += config.num_steps iters += config.num_steps
...@@ -312,8 +312,6 @@ def main(): ...@@ -312,8 +312,6 @@ def main():
if batch_id == 0: if batch_id == 0:
batch_time = 0 batch_time = 0
batch_start_time = time.time() batch_start_time = time.time()
data_feeds["init_hidden"] = init_hidden
data_feeds["init_cell"] = init_cell
else: else:
batch_time = time.time() - batch_start_time batch_time = time.time() - batch_start_time
batch_times.append(batch_time) batch_times.append(batch_time)
...@@ -321,14 +319,19 @@ def main(): ...@@ -321,14 +319,19 @@ def main():
new_lr = generate_new_lr(epoch_id, device_count) new_lr = generate_new_lr(epoch_id, device_count)
data_feeds['learning_rate'] = new_lr data_feeds['learning_rate'] = new_lr
data_feeds["init_hidden"] = init_hidden
data_feeds["init_cell"] = init_cell
fetch_outs = exe.run(train_program, fetch_outs = exe.run(train_program,
feed=data_feeds, feed=data_feeds,
fetch_list=[loss.name, "learning_rate"], fetch_list=[loss.name, "learning_rate", \
last_hidden.name, last_cell.name ],
use_program_cache=True) use_program_cache=True)
cost_train = np.array(fetch_outs[0]) cost_train = np.array(fetch_outs[0])
lr = np.array(fetch_outs[1]) lr = np.array(fetch_outs[1])
init_hidden = np.array(fetch_outs[2])
init_cell = np.array( fetch_outs[3] )
total_loss += cost_train total_loss += cost_train
iters += config.num_steps iters += config.num_steps
......
...@@ -20,6 +20,8 @@ import paddle.fluid.layers as layers ...@@ -20,6 +20,8 @@ import paddle.fluid.layers as layers
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
import numpy as np import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.contrib.layers import basic_lstm
def lm_model(hidden_size, def lm_model(hidden_size,
...@@ -358,6 +360,10 @@ def lm_model(hidden_size, ...@@ -358,6 +360,10 @@ def lm_model(hidden_size,
default_initializer=fluid.initializer.UniformInitializer( default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)) low=-init_scale, high=init_scale))
rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2]) rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2])
elif rnn_model == "basic_lstm":
print("basic api")
rnn_out, last_hidden, last_cell = basic_lstm( x_emb, init_hidden, init_cell, hidden_size, num_layers=num_layers, \
batch_first=True, dropout_prob=dropout, param_attr = ParamAttr( initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale) ), bias_attr = ParamAttr( initializer = fluid.initializer.Constant(0.0) ))
else: else:
print("type not support") print("type not support")
return return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册