diff --git a/PaddleCV/image_classification/utils/fp16_utils.py b/PaddleCV/image_classification/utils/fp16_utils.py index 939ac59db2441af7b190604895f8a467d9844294..bc55f77d7856c9efde805aa82cb59d1193818623 100644 --- a/PaddleCV/image_classification/utils/fp16_utils.py +++ b/PaddleCV/image_classification/utils/fp16_utils.py @@ -103,8 +103,16 @@ def create_master_params_grads(params_grads, main_prog, startup_prog, scale_loss def master_param_to_train_param(master_params_grads, params_grads, main_prog): for idx, m_p_g in enumerate(master_params_grads): - train_p, _ = params_grads[idx] - if train_p.name.startswith("batch_norm"): - continue with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]): + train_p_name = m_p_g[0].name.replace(".master", "") + if train_p_name.startswith("batch_norm"): + continue + train_p = None + # find fp16 param in original params_grads list + for p, g in params_grads: + if p.name == train_p_name: + train_p = p + if not train_p: + print("can not find train param for: ", m_p_g[0].name) + continue cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)