Commit 8891f0e9 authored by slf12

set default scope None test=develop

Parent 34b504eb
@@ -19,6 +19,7 @@ import logging
 import copy
+import numpy as np
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
@@ -216,7 +217,7 @@ def _quant_embedding_abs_max(graph, scope, place, config):
     scope.var(_get_scale_var_name(var_name))
     #set var by tensor array or scale
     _restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place)
-    _restore_var(_get_scale_var_name(var_name), scale, scope, place)
+    _restore_var(_get_scale_var_name(var_name), np.array(scale), scope, place)

     # insert dequantize_abs_max op
     for op_node in embedding_node.outputs:
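A note on the `np.array(scale)` change: the abs_max scale can come out of the quantization step as a plain Python float, while the tensor restore path writes through `LoDTensor.set()`, which expects a numpy ndarray. A minimal sketch of that constraint, assuming `_restore_var` is implemented roughly as below (the helper body here is illustrative, not part of this patch):

```python
import numpy as np
import paddle.fluid as fluid

def _restore_var(name, arr, scope, place):
    # Illustrative body: LoDTensor.set() takes a numpy ndarray plus a place,
    # so passing a bare Python float for the scale would fail here.
    tensor = scope.find_var(name).get_tensor()
    tensor.set(arr, place)

# Wrapping the scalar keeps the call uniform with the quantized-tensor path:
#   _restore_var(_get_scale_var_name(var_name), np.array(scale), scope, place)
```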
@@ -231,12 +232,12 @@ def _quant_embedding_abs_max(graph, scope, place, config):
     graph.safe_remove_nodes(embedding_node)

-def quant_embedding(program, scope, place, config):
+def quant_embedding(program, place, config, scope=None):
     """
     quant lookup_table op parameters
     Args:
         program(fluid.Program): infer program
-        scope(fluid.Scope): the scope to store var, usually is fluid.global_scope()
+        scope(fluid.Scope): the scope to store var; if None, fluid.global_scope() will be used
         place(fluid.CPUPlace or fluid.CUDAPlace): place
         config(dict): config to quant. The keys are 'params_name', 'quantize_type', \
             'quantize_bits', 'dtype', 'threshold'. \
@@ -249,6 +250,7 @@ def quant_embedding(program, scope, place, config):
     """
     assert isinstance(config, dict), "config must be dict"
     config = _merge_config(copy.deepcopy(default_config), config)
+    scope = fluid.global_scope() if scope is None else scope
     graph = IrGraph(core.Graph(program.desc), for_test=True)
     if config['quantize_type'] == 'abs_max':
......
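After this change, callers that keep their parameters in the global scope can drop the explicit `scope` argument. A minimal usage sketch, assuming the `paddleslim.quant` import path and that `quant_embedding` returns the rewritten inference program; the embedding name and config values are illustrative, not taken from the patch:

```python
import paddle.fluid as fluid
from paddleslim.quant import quant_embedding  # import path assumed

startup = fluid.Program()
infer_program = fluid.Program()
with fluid.program_guard(infer_program, startup):
    ids = fluid.data(name='ids', shape=[None, 1], dtype='int64')
    emb = fluid.embedding(
        ids, size=(10000, 64),
        param_attr=fluid.ParamAttr(name='emb'))  # name matches params_name below

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)  # materializes 'emb' in fluid.global_scope()

config = {'params_name': 'emb', 'quantize_type': 'abs_max'}  # hypothetical config
# Before this commit: quant_embedding(infer_program, fluid.global_scope(), place, config)
# Now scope defaults to None and falls back to fluid.global_scope():
quanted_program = quant_embedding(infer_program, place, config)
```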