未验证 提交 c19f7ac1 作者: Hongyu Liu 提交者: GitHub

Merge pull request #1493 from phlrain/add_cudnn_lm

Add a cuDNN LSTM option to the language model, selectable with --rnn_model [static|padding|cudnn]
......@@ -7,5 +7,6 @@ python train.py \
--data_path data/simple-examples/data/ \
--model_type small \
--use_gpu True \
--rnn_model static \
--enable_ce | python _ce.py
......@@ -26,7 +26,12 @@ def parse_args():
"--model_type",
type=str,
default="small",
help="model_type [test|small|med|big]")
help="model_type [test|small|medium|large]")
parser.add_argument(
"--rnn_model",
type=str,
default="static",
help="model_type [static|padding|cudnn]")
parser.add_argument(
"--data_path", type=str, help="all the data for train,valid,test")
parser.add_argument('--para_init', action='store_true')
......
......@@ -28,7 +28,8 @@ def lm_model(hidden_size,
num_layers=2,
num_steps=20,
init_scale=0.1,
dropout=None):
dropout=None,
rnn_model='static'):
def padding_rnn(input_embedding, len=3, init_hidden=None, init_cell=None):
weight_1_arr = []
weight_2_arr = []
......@@ -243,7 +244,7 @@ def lm_model(hidden_size,
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=True,
is_sparse=False,
param_attr=fluid.ParamAttr(
name='embedding_para',
initializer=fluid.initializer.UniformInitializer(
......@@ -255,9 +256,22 @@ def lm_model(hidden_size,
x_emb,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
rnn_out, last_hidden, last_cell = padding_rnn(
x_emb, len=num_steps, init_hidden=init_hidden, init_cell=init_cell)
if rnn_model == "padding":
rnn_out, last_hidden, last_cell = padding_rnn(
x_emb, len=num_steps, init_hidden=init_hidden, init_cell=init_cell)
elif rnn_model == "static":
rnn_out, last_hidden, last_cell = encoder_static(
x_emb, len=num_steps, init_hidden=init_hidden, init_cell=init_cell)
elif rnn_model == "cudnn":
x_emb = layers.transpose( x_emb, perm=[1, 0, 2])
rnn_out, last_hidden, last_cell = layers.lstm( x_emb, init_hidden, init_cell, num_steps, hidden_size, num_layers, \
is_bidirec=False, \
default_initializer=fluid.initializer.UniformInitializer(low=-init_scale, high=init_scale) )
rnn_out = layers.transpose( rnn_out, perm=[1, 0, 2])
else:
print( "type not support")
return
rnn_out = layers.reshape(rnn_out, shape=[-1, num_steps, hidden_size])
......
......@@ -77,6 +77,7 @@ def save_para_npz(train_prog, train_exe):
def train():
args = parse_args()
model_type = args.model_type
rnn_model = args.rnn_model
logger = logging.getLogger("lm")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
......@@ -157,7 +158,8 @@ def train():
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale,
dropout=dropout)
dropout=dropout,
rnn_model=rnn_model)
# clone from default main program and use it as the validation program
main_program = fluid.default_main_program()
inference_program = fluid.default_main_program().clone(for_test=True)
......@@ -206,18 +208,19 @@ def train():
def eval(data):
# when eval the batch_size set to 1
eval_data_iter = reader.get_data_iter(data, 1, num_steps)
eval_data_iter = reader.get_data_iter(data, batch_size, num_steps)
total_loss = 0.0
iters = 0
init_hidden = np.zeros((num_layers, 1, hidden_size), dtype='float32')
init_cell = np.zeros((num_layers, 1, hidden_size), dtype='float32')
init_hidden = np.zeros((num_layers, batch_size, hidden_size), dtype='float32')
init_cell = np.zeros((num_layers, batch_size, hidden_size), dtype='float32')
for batch_id, batch in enumerate(eval_data_iter):
input_data_feed = prepare_input(
batch, init_hidden, init_cell, epoch_id, with_lr=False)
fetch_outs = exe.run(
inference_program,
feed=input_data_feed,
fetch_list=[loss.name, last_hidden.name, last_cell.name])
fetch_list=[loss.name, last_hidden.name, last_cell.name],
use_program_cache=True)
cost_train = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1])
......@@ -284,7 +287,7 @@ def train():
if epoch_id == max_epoch - 1 and args.enable_ce:
print("ptblm\tlstm_language_model_duration\t%s" %
(total_time / max_epoch))
(total_time / max_epoch))
print("ptblm\tlstm_language_model_loss\t%s" % ppl[0])
model_path = os.path.join("model_new/", str(epoch_id))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册