提交 90a1b791 编写于 作者: X xuezhong

fix bug

上级 dd353cbd
@@ -369,9 +369,11 @@ def train(logger, args):
             logger.error('Unsupported optimizer: {}'.format(args.optim))
             exit(-1)
         if args.weight_decay > 0.0:
-            avg_cost_wd = avg_cost + args.weight_decay * l2_loss(
-                main_program)
-            optimizer.minimize(avg_cost_wd)
+            obj_func = avg_cost + args.weight_decay * l2_loss(main_program)
+            optimizer.minimize(obj_func)
+        else:
+            obj_func = avg_cost
+            optimizer.minimize(obj_func)
         # initialize parameters
         place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
# initialize parameters # initialize parameters
place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace() place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
@@ -411,7 +413,7 @@ def train(logger, args):
                 feed_data = batch_reader(batch_list, args)
                 fetch_outs = parallel_executor.run(
                     feed=list(feeder.feed_parallel(feed_data, dev_count)),
-                    fetch_list=[avg_cost_wd.name],
+                    fetch_list=[obj_func.name],
                     return_numpy=False)
                 cost_train = np.array(fetch_outs[0]).mean()
                 total_num += args.batch_size * dev_count
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册