提交 24cb7dfe 编写于 作者: X xuezhong

add weight decay

上级 3e82ae13
...@@ -58,6 +58,11 @@ def parse_args(): ...@@ -58,6 +58,11 @@ def parse_args():
type=float, type=float,
default=0.001, default=0.001,
help="Learning rate used to train the model. (default: %(default)f)") help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
"--weight_decay",
type=float,
default=0.001,
help="Weight decay. (default: %(default)f)")
parser.add_argument( parser.add_argument(
"--use_gpu", "--use_gpu",
type=distutils.util.strtobool, type=distutils.util.strtobool,
......
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import os import os
import random import random
import json import json
import six
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -209,6 +210,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place, ...@@ -209,6 +210,8 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
'yesno_answers': [] 'yesno_answers': []
}) })
if args.result_dir is not None and args.result_name is not None: if args.result_dir is not None and args.result_name is not None:
if not os.path.exists(args.result_dir):
os.makedirs(args.result_dir)
result_file = os.path.join(args.result_dir, args.result_name + '.json') result_file = os.path.join(args.result_dir, args.result_name + '.json')
with open(result_file, 'w') as fout: with open(result_file, 'w') as fout:
for pred_answer in pred_answers: for pred_answer in pred_answers:
...@@ -235,7 +238,10 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place, ...@@ -235,7 +238,10 @@ def validation(inference_program, avg_cost, s_probs, e_probs, feed_order, place,
def train(logger, args): def train(logger, args):
logger.info('Load data_set and vocab...') logger.info('Load data_set and vocab...')
with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin: with open(os.path.join(args.vocab_dir, 'vocab.data'), 'rb') as fin:
vocab = pickle.load(fin) if six.PY2:
vocab = pickle.load(fin)
else:
vocab = pickle.load(fin, encoding='bytes')
logger.info('vocab size is {} and embed dim is {}'.format(vocab.size( logger.info('vocab size is {} and embed dim is {}'.format(vocab.size(
), vocab.embed_dim)) ), vocab.embed_dim))
brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len, brc_data = BRCDataset(args.max_p_num, args.max_p_len, args.max_q_len,
...@@ -259,13 +265,20 @@ def train(logger, args): ...@@ -259,13 +265,20 @@ def train(logger, args):
# build optimizer # build optimizer
if args.optim == 'sgd': if args.optim == 'sgd':
optimizer = fluid.optimizer.SGD( optimizer = fluid.optimizer.SGD(
learning_rate=args.learning_rate) learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
elif args.optim == 'adam': elif args.optim == 'adam':
optimizer = fluid.optimizer.Adam( optimizer = fluid.optimizer.Adam(
learning_rate=args.learning_rate) learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
elif args.optim == 'rprop': elif args.optim == 'rprop':
optimizer = fluid.optimizer.RMSPropOptimizer( optimizer = fluid.optimizer.RMSPropOptimizer(
learning_rate=args.learning_rate) learning_rate=args.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=args.weight_decay))
else: else:
logger.error('Unsupported optimizer: {}'.format(args.optim)) logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1) exit(-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册