diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 202aa76084432b4b2378470919b2e924301f2130..0629f2916b339a6cd19ccadf435a67a17d6da4cc 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -19,16 +19,30 @@ from ..executor import global_scope class InferenceTranspiler: + ''' + Convert the fluid program to optimized inference program. + + There are several optimizations, only fuse batch normalization is supported now. + + Examples: + + .. code-block:: python + + # As InferenceTranspiler will modify the original program, + # please clone before use it. + inference_transpiler_program = program.clone() + t = fluid.InferenceTranspiler() + t.transpile(inference_transpiler_program, place) + ''' + def transpile(self, program, place, scope=None): ''' - Transpile the program. Support only fuse batch normalization now. - - :param program: program to transpile - :type program: Program - :param place: inference place - :type place: Place - :param scope: inference scope - :type scope: Scope or None + Run the transpiler. + + Args: + program (Program): program to transpile + place (Place): inference place + scope (Scope|None): inference Scope ''' if not isinstance(program, Program): raise TypeError("program should be as Program type") @@ -49,36 +63,43 @@ class InferenceTranspiler: can be integrated with them. Doing so will give us a forward acceleration, especially in environments like mobile or embedded. - For input X: - - Conv process: X = input * W + bias - - Batch norm process: X' = (X - mean) / std - - Scale Process: Y = a * X' + b + For input :math:`X`: + + - Conv process: :math:`X = input * W + bias` + - Batch norm process: :math:`X' = (X - mean) / std` + - Scale Process: :math:`Y = a * X' + b` After fuse into one operation: - Y = (input * W + bias - mean) / std * a + b - = input * a * W / std + ((bias - mean) / std * a + b) + .. math:: + + Y &= (input * W + bias - mean) / std * a + b \\\\ + &= input * a * W / std + ((bias - mean) / std * a + b) The operator transformation is: + - before: + - conv->batch_norm->any_other_op (bias == 0) - conv->elementwise_add->batch_norm->any_other_op (bias != 0) + - after: + - conv->elementwise_add->any_other_op The transpile stages are: + 1. insert elementwise_add op when bias == 0. 2. fuse the batch_norm's parameters to conv and elementwise_add operators. 3. remove batch_norm ops which are not used in any other ops. 4. adjust the input of any_other_op to be the output of elementwise_add operator. 5. remove unused variables. - :param program: program to transpile - :type program: Program - :param place: inference place - :type place: Place - :param scope: inference scope - :type scope: Scope + Args: + program (Program): program to transpile + place (Place): inference place + scope (Scope): inference Scope + ''' self.scope = scope self.place = place