提交 81c47b21 编写于 作者: L Luo Tao

add type check and default scope

上级 01b88f21
...@@ -13,26 +13,35 @@ ...@@ -13,26 +13,35 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import os from framework import Program
import shutil from executor import global_scope
from . import core from . import core
class InferenceTranspiler: class InferenceTranspiler:
def transpile(self, program, scope, place): def transpile(self, program, place, scope=None):
''' '''
Transpile the program. Support only fuse batch normalization now. Transpile the program. Support only fuse batch normalization now.
:param program: program to transpile :param program: program to transpile
:type program: Program :type program: Program
:param scope: inference scope
:type scope: Scope
:param place: inference place :param place: inference place
:type place: Place :type place: Place
:param scope: inference scope
:type scope: Scope or None
''' '''
self.fuse_batch_norm(program, scope, place) if not isinstance(program, Program):
raise TypeError("program should be as Program type")
def fuse_batch_norm(self, program, scope, place): if not isinstance(place, core.CPUPlace) and not isinstance(
place, core.CUDAPlace):
raise TypeError("place should be as CPUPlace/CUDAPlace type")
if scope is None:
scope = global_scope()
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self.fuse_batch_norm(program, place, scope)
def fuse_batch_norm(self, program, place, scope):
''' '''
Transpile the program by fused batch normalization. Transpile the program by fused batch normalization.
...@@ -66,10 +75,10 @@ class InferenceTranspiler: ...@@ -66,10 +75,10 @@ class InferenceTranspiler:
:param program: program to transpile :param program: program to transpile
:type program: Program :type program: Program
:param scope: inference scope
:type scope: Scope
:param place: inference place :param place: inference place
:type place: Place :type place: Place
:param scope: inference scope
:type scope: Scope
''' '''
self.scope = scope self.scope = scope
self.place = place self.place = place
......
...@@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None): ...@@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None):
# Use inference_transpiler to speedup # Use inference_transpiler to speedup
inference_transpiler_program = inference_program.clone() inference_transpiler_program = inference_program.clone()
t = fluid.InferenceTranspiler() t = fluid.InferenceTranspiler()
t.transpile(inference_transpiler_program, inference_scope, place) t.transpile(inference_transpiler_program, place)
# Construct feed as a dictionary of {feed_target_name: feed_target_data} # Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets. # and results will contain a list of data corresponding to fetch_targets.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册