提交 90a1b791 编写于 作者: X xuezhong

fix bug

上级 dd353cbd
......@@ -369,9 +369,11 @@ def train(logger, args):
logger.error('Unsupported optimizer: {}'.format(args.optim))
exit(-1)
if args.weight_decay > 0.0:
avg_cost_wd = avg_cost + args.weight_decay * l2_loss(
main_program)
optimizer.minimize(avg_cost_wd)
obj_func = avg_cost + args.weight_decay * l2_loss(main_program)
optimizer.minimize(obj_func)
else:
obj_func = avg_cost
optimizer.minimize(obj_func)
# initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
......@@ -411,7 +413,7 @@ def train(logger, args):
feed_data = batch_reader(batch_list, args)
fetch_outs = parallel_executor.run(
feed=list(feeder.feed_parallel(feed_data, dev_count)),
fetch_list=[avg_cost_wd.name],
fetch_list=[obj_func.name],
return_numpy=False)
cost_train = np.array(fetch_outs[0]).mean()
total_num += args.batch_size * dev_count
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册