Commit 24cb7dfe authored by xuezhong

add weight decay

Parent 3e82ae13
@@ -58,6 +58,11 @@ def parse_args():
         type=float,
         default=0.001,
         help="Learning rate used to train the model. (default: %(default)f)")
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=0.001,
+        help="Weight decay. (default: %(default)f)")
     parser.add_argument(
         "--use_gpu",
         type=distutils.util.strtobool,
......
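For reference, a minimal, self-contained sketch of how the new --weight_decay flag behaves on the command line; the parser fragment mirrors the hunk above, and the override value 0.0005 is only illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--weight_decay",
    type=float,
    default=0.001,
    help="Weight decay. (default: %(default)f)")

# Parse an explicit override; with no flag the default 0.001 is used.
args = parser.parse_args(["--weight_decay", "0.0005"])
print(args.weight_decay)  # -> 0.0005
```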
@@ -21,6 +21,7 @@ import time
 import os
 import random
 import json
+import six
 import paddle
 import paddle.fluid as fluid
@@ -209,6 +210,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
             'yesno_answers': []
         })
     if args.result_dir is not None and args.result_name is not None:
+        if not os.path.exists(args.result_dir):
+            os.makedirs(args.result_dir)
         result_file = os.path.join(args.result_dir, args.result_name + '.json')
         with open(result_file, 'w') as fout:
             for pred_answer in pred_answers:
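A hedged, stand-alone sketch of the directory-creation pattern added above; result_dir, result_name, and the JSON-lines write are placeholders for args.result_dir, args.result_name, and the truncated write loop, not the script's real values:

```python
import json
import os

# Placeholders for args.result_dir / args.result_name and the model's predictions.
result_dir, result_name = 'results', 'dev_predictions'
pred_answers = [{'question_id': 0, 'answers': [], 'yesno_answers': []}]

if result_dir is not None and result_name is not None:
    # Create the output directory on demand, as the new diff lines do.
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    result_file = os.path.join(result_dir, result_name + '.json')
    with open(result_file, 'w') as fout:
        for pred_answer in pred_answers:
            # Assumed one-JSON-object-per-line format; the original write is not shown.
            fout.write(json.dumps(pred_answer, ensure_ascii=False) + '\n')
```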
@@ -235,7 +238,10 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
 def train(logger, args):
     logger.info('Load data_set and vocab...')
     with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
-        vocab = pickle.load(fin)
+        if six.PY2:
+            vocab = pickle.load(fin)
+        else:
+            vocab = pickle.load(fin, encoding='bytes')
     logger.info('vocab size is {} and embed dim is {}'.format(vocab.size(
     ), vocab.embed_dim))
     brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
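The six.PY2 branch is the usual way to read a Python 2 pickle from either interpreter. A self-contained sketch of that pattern, where the dict stands in for the script's Vocab object and vocab.data is a file the sketch itself writes:

```python
import pickle
import six

# Write a protocol-2 pickle, as a Python 2 process would have done.
with open('vocab.data', 'wb') as fout:
    pickle.dump({'size': 10, 'embed_dim': 300}, fout, protocol=2)

with open('vocab.data', 'rb') as fin:
    if six.PY2:
        vocab = pickle.load(fin)
    else:
        # Python 3 needs an explicit encoding to decode Python 2 byte strings.
        vocab = pickle.load(fin, encoding='bytes')
```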
@@ -259,13 +265,20 @@ def train(logger, args):
     # build optimizer
     if args.optim == 'sgd':
         optimizer = fluid.optimizer.SGD(
-            learning_rate=args.learning_rate)
+            learning_rate=args.learning_rate,
+            regularization=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=args.weight_decay))
     elif args.optim == 'adam':
         optimizer = fluid.optimizer.Adam(
-            learning_rate=args.learning_rate)
+            learning_rate=args.learning_rate,
+            regularization=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=args.weight_decay))
     elif args.optim == 'rprop':
         optimizer = fluid.optimizer.RMSPropOptimizer(
-            learning_rate=args.learning_rate)
+            learning_rate=args.learning_rate,
+            regularization=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=args.weight_decay))
     else:
         logger.error('Unsupported optimizer: {}'.format(args.optim))
         exit(-1)
......
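The three optimizer branches differ only in the constructor they call; a sketch of the same fluid API calls factored into one helper (build_optimizer is a hypothetical name, not part of this commit):

```python
import paddle.fluid as fluid

def build_optimizer(optim, learning_rate, weight_decay):
    # L2 weight decay attached through the regularization argument,
    # exactly as the three branches in train() now do.
    reg = fluid.regularizer.L2DecayRegularizer(
        regularization_coeff=weight_decay)
    if optim == 'sgd':
        return fluid.optimizer.SGD(
            learning_rate=learning_rate, regularization=reg)
    elif optim == 'adam':
        return fluid.optimizer.Adam(
            learning_rate=learning_rate, regularization=reg)
    elif optim == 'rprop':
        return fluid.optimizer.RMSPropOptimizer(
            learning_rate=learning_rate, regularization=reg)
    raise ValueError('Unsupported optimizer: {}'.format(optim))
```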