From 81c47b21ef742dca9a7bfad16059575ce57f20aa Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 17 Apr 2018 19:38:20 +0800 Subject: [PATCH] add type check and default scope --- python/paddle/fluid/inference_transpiler.py | 29 ++++++++++++------- .../tests/book/test_image_classification.py | 2 +- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/inference_transpiler.py b/python/paddle/fluid/inference_transpiler.py index be8a62795..39b01610f 100644 --- a/python/paddle/fluid/inference_transpiler.py +++ b/python/paddle/fluid/inference_transpiler.py @@ -13,26 +13,35 @@ # limitations under the License. import numpy as np -import os -import shutil +from framework import Program +from executor import global_scope from . import core class InferenceTranspiler: - def transpile(self, program, scope, place): + def transpile(self, program, place, scope=None): ''' Transpile the program. Support only fuse batch normalization now. :param program: program to transpile :type program: Program - :param scope: inference scope - :type scope: Scope :param place: inference place :type place: Place + :param scope: inference scope + :type scope: Scope or None ''' - self.fuse_batch_norm(program, scope, place) - - def fuse_batch_norm(self, program, scope, place): + if not isinstance(program, Program): + raise TypeError("program should be as Program type") + if not isinstance(place, core.CPUPlace) and not isinstance( + place, core.CUDAPlace): + raise TypeError("place should be as CPUPlace/CUDAPlace type") + if scope is None: + scope = global_scope() + if not isinstance(scope, core.Scope): + raise TypeError("scope should be as Scope type or None") + self.fuse_batch_norm(program, place, scope) + + def fuse_batch_norm(self, program, place, scope): ''' Transpile the program by fused batch normalization. @@ -66,10 +75,10 @@ class InferenceTranspiler: :param program: program to transpile :type program: Program - :param scope: inference scope - :type scope: Scope :param place: inference place :type place: Place + :param scope: inference scope + :type scope: Scope ''' self.scope = scope self.place = place diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py index aeacca575..0027b651e 100644 --- a/python/paddle/fluid/tests/book/test_image_classification.py +++ b/python/paddle/fluid/tests/book/test_image_classification.py @@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None): # Use inference_transpiler to speedup inference_transpiler_program = inference_program.clone() t = fluid.InferenceTranspiler() - t.transpile(inference_transpiler_program, inference_scope, place) + t.transpile(inference_transpiler_program, place) # Construct feed as a dictionary of {feed_target_name: feed_target_data} # and results will contain a list of data corresponding to fetch_targets. -- GitLab