提交 93107ce1 编写于 作者: C chengduoZH

add regularization for test_machine_tranlation

上级 dd8dc0e0
...@@ -176,7 +176,6 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -176,7 +176,6 @@ class L1DecayRegularizer(WeightDecayRegularizer):
dtype="float32", shape=param.shape, lod_level=param.lod_level) dtype="float32", shape=param.shape, lod_level=param.lod_level)
if grad.type == core.VarDesc.VarType.SELECTED_ROWS: if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# add concat_rows
decay = block.create_var( decay = block.create_var(
dtype="float32", dtype="float32",
shape=param.shape, shape=param.shape,
......
...@@ -181,7 +181,10 @@ def train_main(use_cuda, is_sparse, is_local=True): ...@@ -181,7 +181,10 @@ def train_main(use_cuda, is_sparse, is_local=True):
cost = pd.cross_entropy(input=rnn_out, label=label) cost = pd.cross_entropy(input=rnn_out, label=label)
avg_cost = pd.mean(cost) avg_cost = pd.mean(cost)
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4) optimizer = fluid.optimizer.Adagrad(
learning_rate=1e-4,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.1))
optimize_ops, params_grads = optimizer.minimize(avg_cost) optimize_ops, params_grads = optimizer.minimize(avg_cost)
train_data = paddle.batch( train_data = paddle.batch(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册