diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py index 7e39a9e2d9c743681320eaa70e0d75476844018c..defb19a8a6ecda8f2a747033323ccb8eeda07281 100644 --- a/paddleslim/dist/single_distiller.py +++ b/paddleslim/dist/single_distiller.py @@ -95,7 +95,7 @@ def merge(teacher_program, def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, - student_var2_name, program): + student_var2_name, program=fluid.default_main_program()): """ Combine variables from student model and teacher model by fsp-loss. Args: @@ -121,7 +121,8 @@ def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, return fsp_loss -def l2_loss(teacher_var_name, student_var_name, program): +def l2_loss(teacher_var_name, student_var_name, + program=fluid.default_main_program()): """ Combine variables from student model and teacher model by l2-loss. Args: @@ -139,7 +140,7 @@ def l2_loss(teacher_var_name, student_var_name, program): def soft_label_loss(teacher_var_name, student_var_name, - program, + program=fluid.default_main_program(), teacher_temperature=1., student_temperature=1.): """ @@ -165,7 +166,7 @@ def soft_label_loss(teacher_var_name, return soft_label_loss -def loss(program, loss_func, **kwargs): +def loss(loss_func, program=fluid.default_main_program(), **kwargs): """ Combine variables from student model and teacher model by self defined loss. Args: