提交 8891f0e9 编写于 作者: S slf12

set default scope None test=develop

上级 34b504eb
...@@ -19,6 +19,7 @@ import logging ...@@ -19,6 +19,7 @@ import logging
import copy import copy
import numpy as np import numpy as np
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid import core from paddle.fluid import core
...@@ -216,7 +217,7 @@ def _quant_embedding_abs_max(graph, scope, place, config): ...@@ -216,7 +217,7 @@ def _quant_embedding_abs_max(graph, scope, place, config):
scope.var(_get_scale_var_name(var_name)) scope.var(_get_scale_var_name(var_name))
#set var by tensor array or scale #set var by tensor array or scale
_restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place) _restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place)
_restore_var(_get_scale_var_name(var_name), scale, scope, place) _restore_var(_get_scale_var_name(var_name), np.array(scale), scope, place)
# insert dequantize_abs_max op # insert dequantize_abs_max op
for op_node in embedding_node.outputs: for op_node in embedding_node.outputs:
...@@ -231,12 +232,12 @@ def _quant_embedding_abs_max(graph, scope, place, config): ...@@ -231,12 +232,12 @@ def _quant_embedding_abs_max(graph, scope, place, config):
graph.safe_remove_nodes(embedding_node) graph.safe_remove_nodes(embedding_node)
def quant_embedding(program, scope, place, config): def quant_embedding(program, place, config, scope=None):
""" """
quant lookup_table op parameters quant lookup_table op parameters
Args: Args:
program(fluid.Program): infer program program(fluid.Program): infer program
scope(fluid.Scope): the scope to store var, usually is fluid.global_scope() scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope()
place(fluid.CPUPlace or fluid.CUDAPlace): place place(fluid.CPUPlace or fluid.CUDAPlace): place
config(dict): config to quant. The keys are 'params_name', 'quantize_type', \ config(dict): config to quant. The keys are 'params_name', 'quantize_type', \
'quantize_bits', 'dtype', 'threshold'. \ 'quantize_bits', 'dtype', 'threshold'. \
...@@ -249,6 +250,7 @@ def quant_embedding(program, scope, place, config): ...@@ -249,6 +250,7 @@ def quant_embedding(program, scope, place, config):
""" """
assert isinstance(config, dict), "config must be dict" assert isinstance(config, dict), "config must be dict"
config = _merge_config(copy.deepcopy(default_config), config) config = _merge_config(copy.deepcopy(default_config), config)
scope = fluid.global_scope() if scope is None else scope
graph = IrGraph(core.Graph(program.desc), for_test=True) graph = IrGraph(core.Graph(program.desc), for_test=True)
if config['quantize_type'] == 'abs_max': if config['quantize_type'] == 'abs_max':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册