未验证 提交 91cfa6f8 编写于 作者: Z zhang wenhui 提交者: GitHub

fix dygraph in gru (#4568)

* fix dygraph in gru

* fix dy gru4rec
上级 4c3e413f
......@@ -247,31 +247,7 @@ def train_ptb_lm():
model_type = args.model_type
vocab_size = 37484
-if model_type == "test":
-num_layers = 1
-batch_size = 2
-hidden_size = 10
-num_steps = 4
-init_scale = 0.1
-max_grad_norm = 5.0
-epoch_start_decay = 1
-max_epoch = 1
-dropout = 0.0
-lr_decay = 0.5
-base_learning_rate = 1.0
-elif model_type == "small":
-num_layers = 2
-batch_size = 20
-hidden_size = 200
-num_steps = 20
-init_scale = 0.1
-max_grad_norm = 5.0
-epoch_start_decay = 4
-max_epoch = 2
-dropout = 0.0
-lr_decay = 0.5
-base_learning_rate = 1.0
-elif model_type == "gru4rec":
+if model_type == "gru4rec":
num_layers = 1
batch_size = 500
hidden_size = 100
......@@ -283,30 +259,6 @@ def train_ptb_lm():
dropout = 0.0
lr_decay = 0.5
base_learning_rate = 0.05
-elif model_type == "medium":
-num_layers = 2
-batch_size = 20
-hidden_size = 650
-num_steps = 35
-init_scale = 0.05
-max_grad_norm = 5.0
-epoch_start_decay = 6
-max_epoch = 39
-dropout = 0.5
-lr_decay = 0.8
-base_learning_rate = 1.0
-elif model_type == "large":
-num_layers = 2
-batch_size = 20
-hidden_size = 1500
-num_steps = 35
-init_scale = 0.04
-max_grad_norm = 10.0
-epoch_start_decay = 14
-max_epoch = 55
-dropout = 0.65
-lr_decay = 1.0 / 1.15
-base_learning_rate = 1.0
else:
print("model type not support")
return
......@@ -433,7 +385,7 @@ def train_ptb_lm():
out_loss = dy_loss.numpy()
acc_ = acc.numpy()[0]
-init_hidden = last_hidden
+init_hidden = last_hidden.detach()
dy_loss.backward()
sgd.minimize(dy_loss)
ptb_model.clear_gradients()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册