diff --git a/test_tipc/common_func.sh b/test_tipc/common_func.sh index e0459366ed7d86d239624dc47937d91cc7704894..4aa3db6ca1d8091c99bf9a50a946417d94b7a791 100644 --- a/test_tipc/common_func.sh +++ b/test_tipc/common_func.sh @@ -24,6 +24,17 @@ function func_parser_value_lite(){ echo ${tmp} } +function func_set_amp_params(){ + key=$1 + value=$2 + + if [[ ${value} = "fp16" ]];then + echo "-o AMP.scale_loss=128 -o AMP.use_dynamic_loss_scaling=True -o AMP.level=O2" + else + echo " " + fi +} + function func_set_params(){ key=$1 value=$2 diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh index ad5b301f1ef5bacdd82cafff35d3d61699b38151..9ec79bb29ce69b908fb7c003086e013c55de2517 100644 --- a/test_tipc/test_train_inference_python.sh +++ b/test_tipc/test_train_inference_python.sh @@ -231,7 +231,7 @@ else continue fi - set_autocast=$(func_set_params "${autocast_key}" "${autocast}") + set_autocast=$(func_set_amp_params "${autocast_key}" "${autocast}") set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}") set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}") set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")