From 5beed1ecc8469c3ae1cb842659803d5bcf5593cc Mon Sep 17 00:00:00 2001
From: baiyfbupt
Date: Mon, 25 Nov 2019 17:30:12 +0800
Subject: [PATCH] make loss function program arg dispensable

---
 paddleslim/dist/single_distiller.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index 7e39a9e2..defb19a8 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -95,7 +95,7 @@ def merge(teacher_program,
 
 
 def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
-             student_var2_name, program):
+             student_var2_name, program=fluid.default_main_program()):
     """
     Combine variables from student model and teacher model by fsp-loss.
     Args:
@@ -121,7 +121,8 @@ def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
     return fsp_loss
 
 
-def l2_loss(teacher_var_name, student_var_name, program):
+def l2_loss(teacher_var_name, student_var_name,
+            program=fluid.default_main_program()):
     """
     Combine variables from student model and teacher model by l2-loss.
     Args:
@@ -139,7 +140,7 @@ def l2_loss(teacher_var_name, student_var_name, program):
     """
 def soft_label_loss(teacher_var_name,
                     student_var_name,
-                    program,
+                    program=fluid.default_main_program(),
                     teacher_temperature=1.,
                     student_temperature=1.):
     """
@@ -165,7 +166,7 @@
     return soft_label_loss
 
 
-def loss(program, loss_func, **kwargs):
+def loss(loss_func, program=fluid.default_main_program(), **kwargs):
     """
     Combine variables from student model and teacher model by self defined loss.
     Args:
-- 
GitLab
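
Usage note: with program now defaulting to fluid.default_main_program(), the
distillation loss helpers can be called without threading a Program handle
through every call. A minimal sketch, assuming PaddlePaddle 1.x and that
merge() has already copied the teacher's variables (with the default
'teacher_' prefix) into the student's main program; the variable names below
are hypothetical:

import paddle.fluid as fluid
from paddleslim.dist import l2_loss

# merge(teacher_program, student_program, data_name_map, place) is assumed
# to have already fused the teacher graph into the student's main program,
# so both variables below exist in fluid.default_main_program().

# Before this patch the program had to be passed explicitly:
#   distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0',
#                          fluid.default_main_program())

# After it, the argument is dispensable and falls back to the default
# main program:
distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0')

One caveat: Python evaluates default argument values once, at function
definition time, so program=fluid.default_main_program() binds whatever the
default main program is when single_distiller is imported. Code that builds
its graph inside fluid.program_guard with a different Program should still
pass program explicitly.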