提交 19ed6716 编写于 作者: 一米半's avatar 一米半 提交者: pkpk

add check before using cuda and change default config (#2726)

上级 460e5431
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
}, },
"optimizer": { "optimizer": {
"class_name": "AdamOptimizer", "class_name": "AdamOptimizer",
"learning_rate": 0.001, "learning_rate": 0.2,
"beta1": 0.9, "beta1": 0.9,
"beta2": 0.999, "beta2": 0.999,
"epsilon": 1e-08 "epsilon": 1e-08
......
...@@ -73,7 +73,10 @@ run_type_g.add_arg( ...@@ -73,7 +73,10 @@ run_type_g.add_arg(
"When task_mode is pairwise, lamda is the threshold for calculating the accuracy." "When task_mode is pairwise, lamda is the threshold for calculating the accuracy."
) )
parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.') parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run the task with continuous evaluation logs.')
args = parser.parse_args() args = parser.parse_args()
...@@ -280,9 +283,9 @@ def train(conf_dict, args): ...@@ -280,9 +283,9 @@ def train(conf_dict, args):
except: except:
logging.info("ce info err!") logging.info("ce info err!")
print("kpis\teach_step_duration_%s_card%s\t%s" % print("kpis\teach_step_duration_%s_card%s\t%s" %
(args.task_name, card_num, ce_time)) (args.task_name, card_num, ce_time))
print("kpis\ttrain_loss_%s_card%s\t%f" % print("kpis\ttrain_loss_%s_card%s\t%f" %
(args.task_name, card_num, ce_loss)) (args.task_name, card_num, ce_loss))
if args.do_test: if args.do_test:
if args.task_mode == "pairwise": if args.task_mode == "pairwise":
...@@ -454,6 +457,15 @@ def main(conf_dict, args): ...@@ -454,6 +457,15 @@ def main(conf_dict, args):
if __name__ == "__main__": if __name__ == "__main__":
utils.print_arguments(args) utils.print_arguments(args)
try:
if fluid.is_compiled_with_cuda() != True and args.use_cuda == True:
print(
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\nPlease: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
)
sys.exit(1)
except Exception as e:
pass
utils.init_log("./log/TextSimilarityNet") utils.init_log("./log/TextSimilarityNet")
conf_dict = config.SimNetConfig(args) conf_dict = config.SimNetConfig(args)
main(conf_dict, args) main(conf_dict, args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册