Commit 95a95672 authored by wanghaoshuang

Merge branch 'develop' into 'develop'

Make distillation loss function program arg dispensable

See merge request !53
@@ -95,7 +95,7 @@ def merge(teacher_program,
 
 
 def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
-             student_var2_name, program):
+             student_var2_name, program=fluid.default_main_program()):
     """
     Combine variables from student model and teacher model by fsp-loss.
     Args:
@@ -107,7 +107,8 @@ def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
         student_var2_name(str): The name of student_var2. Except for the
                                 second dimension, all other dimensions should
                                 be consistent with student_var1.
         program(Program): The input distiller program.
+                          default: fluid.default_main_program()
     Return(Variable): fsp distiller loss.
     """
     teacher_var1 = program.global_block().var(teacher_var1_name)
@@ -121,13 +122,15 @@ def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
     return fsp_loss
 
 
-def l2_loss(teacher_var_name, student_var_name, program):
+def l2_loss(teacher_var_name, student_var_name,
+            program=fluid.default_main_program()):
     """
     Combine variables from student model and teacher model by l2-loss.
     Args:
         teacher_var_name(str): The name of teacher_var.
         student_var_name(str): The name of student_var.
         program(Program): The input distiller program.
+                          default: fluid.default_main_program()
     Return(Variable): l2 distiller loss.
     """
     student_var = program.global_block().var(student_var_name)
@@ -139,7 +142,7 @@ def l2_loss(teacher_var_name, student_var_name, program):
 
 def soft_label_loss(teacher_var_name,
                     student_var_name,
-                    program,
+                    program=fluid.default_main_program(),
                     teacher_temperature=1.,
                     student_temperature=1.):
     """
@@ -147,7 +150,8 @@ def soft_label_loss(teacher_var_name,
     Args:
         teacher_var_name(str): The name of teacher_var.
         student_var_name(str): The name of student_var.
         program(Program): The input distiller program.
+                          default: fluid.default_main_program()
         teacher_temperature(float): Temperature used to divide
                                     teacher_feature_map before softmax. default: 1.0
         student_temperature(float): Temperature used to divide
@@ -165,11 +169,12 @@ def soft_label_loss(teacher_var_name,
     return soft_label_loss
 
 
-def loss(program, loss_func, **kwargs):
+def loss(loss_func, program=fluid.default_main_program(), **kwargs):
     """
     Combine variables from student model and teacher model by self defined loss.
     Args:
         program(Program): The input distiller program.
+                          default: fluid.default_main_program()
         loss_func(function): The user self defined loss function.
     Return(Variable): self defined distiller loss.
     """
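For reference, a minimal sketch of what this change enables for callers. The variable names and the paddleslim.dist import path below are illustrative assumptions; only the function signatures shown in the diff above are taken from the source, and each name passed in must already exist in the target program's global block, since the loss functions look variables up via program.global_block().var(name).

    import paddle.fluid as fluid
    from paddleslim.dist import l2_loss, soft_label_loss  # import path assumed

    # Before this change, the program argument had to be passed explicitly:
    # distill_loss = l2_loss('teacher_feature', 'student_feature',
    #                        fluid.default_main_program())

    # After this change, it can be omitted for the common case and falls
    # back to fluid.default_main_program():
    distill_loss = l2_loss('teacher_feature', 'student_feature')

    # A non-default program can still be supplied by keyword:
    other_program = fluid.Program()
    kd_loss = soft_label_loss('teacher_logits', 'student_logits',
                              program=other_program,
                              teacher_temperature=2.,
                              student_temperature=2.)

    # Note that the self-defined helper also reordered its parameters to
    # loss(loss_func, program=fluid.default_main_program(), **kwargs), so
    # existing positional callers of loss(program, loss_func) must be updated.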