提交 81c47b21 编写于 作者: L Luo Tao

add type check and default scope

上级 01b88f21
......@@ -13,26 +13,35 @@
# limitations under the License.
import numpy as np
import os
import shutil
from framework import Program
from executor import global_scope
from . import core
class InferenceTranspiler:
def transpile(self, program, scope, place):
def transpile(self, program, place, scope=None):
'''
Transpile the program. Support only fuse batch normalization now.
:param program: program to transpile
:type program: Program
:param scope: inference scope
:type scope: Scope
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope or None
'''
self.fuse_batch_norm(program, scope, place)
def fuse_batch_norm(self, program, scope, place):
if not isinstance(program, Program):
raise TypeError("program should be as Program type")
if not isinstance(place, core.CPUPlace) and not isinstance(
place, core.CUDAPlace):
raise TypeError("place should be as CPUPlace/CUDAPlace type")
if scope is None:
scope = global_scope()
if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None")
self.fuse_batch_norm(program, place, scope)
def fuse_batch_norm(self, program, place, scope):
'''
Transpile the program by fused batch normalization.
......@@ -66,10 +75,10 @@ class InferenceTranspiler:
:param program: program to transpile
:type program: Program
:param scope: inference scope
:type scope: Scope
:param place: inference place
:type place: Place
:param scope: inference scope
:type scope: Scope
'''
self.scope = scope
self.place = place
......
......@@ -229,7 +229,7 @@ def infer(use_cuda, save_dirname=None):
# Use inference_transpiler to speedup
inference_transpiler_program = inference_program.clone()
t = fluid.InferenceTranspiler()
t.transpile(inference_transpiler_program, inference_scope, place)
t.transpile(inference_transpiler_program, place)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册