From 88f849b3deaa34b14bce2e9b7ebc47a8dcb4f97b Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Tue, 11 Feb 2020 11:17:36 +0800
Subject: [PATCH] Format KD english API (#102)

---
 paddleslim/dist/single_distiller.py | 90 +++++++++++++++++++----------
 1 file changed, 58 insertions(+), 32 deletions(-)

diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index 5e04134d..c5824851 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -20,21 +20,31 @@ def merge(teacher_program,
           student_program,
           data_name_map,
           place,
-          scope=fluid.global_scope(),
+          scope=None,
           name_prefix='teacher_'):
-    """
-    Merge teacher program into student program and add a uniform prefix to the
+    """Merge teacher program into student program and add a uniform prefix to the
     names of all vars in teacher program
+
     Args:
         teacher_program(Program): The input teacher model paddle program
         student_program(Program): The input student model paddle program
-        data_map_map(dict): Describe the mapping between the teacher var name
-            and the student var name
+        data_name_map(dict): Mapping of teacher input interface name and student
+                             input interface name, where key of dict is the
+                             input name of teacher_program, and value is the
+                             input name of student_program.
         place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
             paddle run on which device.
-        scope(Scope): The input scope
+        scope(Scope): This parameter indicates the variable scope used by
+                      the program. If not specified, the default global scope
+                      will be used. Default: None
         name_prefix(str): Name prefix added for all vars of the teacher program.
+                          Default: 'teacher_'
+
+    Returns:
+        None
     """
+    if scope==None:
+        scope = fluid.global_scope()
     teacher_program = teacher_program.clone(for_test=True)
     for teacher_var in teacher_program.list_vars():
         skip_rename = False
@@ -89,9 +99,9 @@ def fsp_loss(teacher_var1_name,
              teacher_var2_name,
              student_var1_name,
              student_var2_name,
-             program=fluid.default_main_program()):
-    """
-    Combine variables from student model and teacher model by fsp-loss.
+             program=None):
+    """Combine variables from student model and teacher model by fsp-loss.
+
     Args:
         teacher_var1_name(str): The name of teacher_var1.
         teacher_var2_name(str): The name of teacher_var2. Except for the
@@ -101,10 +111,14 @@ def fsp_loss(teacher_var1_name,
         student_var2_name(str): The name of student_var2. Except for the
             second dimension, all other dimensions should be consistent with
             student_var1.
-        program(Program): The input distiller program.
-                          default: fluid.default_main_program()
-    Return(Variable): fsp distiller loss.
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
+
+    Returns:
+        Variable: fsp distiller loss.
     """
+    if program==None:
+        program=fluid.default_main_program()
     teacher_var1 = program.global_block().var(teacher_var1_name)
     teacher_var2 = program.global_block().var(teacher_var2_name)
     student_var1 = program.global_block().var(student_var1_name)
@@ -118,16 +132,20 @@ def fsp_loss(teacher_var1_name,

 def l2_loss(teacher_var_name,
             student_var_name,
-            program=fluid.default_main_program()):
-    """
-    Combine variables from student model and teacher model by l2-loss.
+            program=None):
+    """Combine variables from student model and teacher model by l2-loss.
+
     Args:
         teacher_var_name(str): The name of teacher_var.
         student_var_name(str): The name of student_var.
-        program(Program): The input distiller program.
-                          default: fluid.default_main_program()
-    Return(Variable): l2 distiller loss.
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
+
+    Returns:
+        Variable: l2 distiller loss.
     """
+    if program==None:
+        program=fluid.default_main_program()
     student_var = program.global_block().var(student_var_name)
     teacher_var = program.global_block().var(teacher_var_name)
     l2_loss = fluid.layers.reduce_mean(
@@ -137,22 +155,26 @@ def l2_loss(teacher_var_name,

 def soft_label_loss(teacher_var_name,
                     student_var_name,
-                    program=fluid.default_main_program(),
+                    program=None,
                     teacher_temperature=1.,
                     student_temperature=1.):
-    """
-    Combine variables from student model and teacher model by soft-label-loss.
+    """Combine variables from student model and teacher model by soft-label-loss.
+
     Args:
         teacher_var_name(str): The name of teacher_var.
         student_var_name(str): The name of student_var.
-        program(Program): The input distiller program.
-                          default: fluid.default_main_program()
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
         teacher_temperature(float): Temperature used to divide
-            teacher_feature_map before softmax. default: 1.0
+            teacher_feature_map before softmax. Default: 1.0
         student_temperature(float): Temperature used to divide
-            student_feature_map before softmax. default: 1.0
-    Return(Variable): l2 distiller loss.
+            student_feature_map before softmax. Default: 1.0
+
+    Returns:
+        Variable: soft label distiller loss.
     """
+    if program==None:
+        program=fluid.default_main_program()
     student_var = program.global_block().var(student_var_name)
     teacher_var = program.global_block().var(teacher_var_name)
     student_var = fluid.layers.softmax(student_var / student_temperature)
@@ -164,15 +186,19 @@ def soft_label_loss(teacher_var_name,
     return soft_label_loss


-def loss(loss_func, program=fluid.default_main_program(), **kwargs):
-    """
-    Combine variables from student model and teacher model by self defined loss.
+def loss(loss_func, program=None, **kwargs):
+    """Combine variables from student model and teacher model by self defined loss.
+
     Args:
-        program(Program): The input distiller program.
-                          default: fluid.default_main_program()
+        program(Program): The input distiller program. If not specified,
+                          the default program will be used. Default: None
         loss_func(function): The user self defined loss function.
-    Return(Variable): self defined distiller loss.
+
+    Returns:
+        Variable: self defined distiller loss.
     """
+    if program==None:
+        program=fluid.default_main_program()
     func_parameters = {}
     for item in kwargs.items():
         if isinstance(item[1], str):
-- 
GitLab