Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
451896fc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
451896fc
编写于
1月 19, 2019
作者:
W
WangZhen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init quantization.
上级
0b6447a4
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
391 addition
and
6 deletion
+391
-6
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+1
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+4
-0
paddle/fluid/pybind/ir.cc
paddle/fluid/pybind/ir.cc
+6
-0
paddle/fluid/pybind/protobuf.cc
paddle/fluid/pybind/protobuf.cc
+6
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+13
-1
python/paddle/fluid/contrib/slim/graph/graph.py
python/paddle/fluid/contrib/slim/graph/graph.py
+74
-5
python/paddle/fluid/contrib/slim/quantization/quantization_performer.py
...fluid/contrib/slim/quantization/quantization_performer.py
+287
-0
未找到文件。
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
451896fc
...
...
@@ -233,3 +233,4 @@ USE_PASS(sequential_execution_pass);
USE_PASS
(
all_reduce_deps_pass
);
USE_PASS
(
modify_op_lock_and_record_event_pass
);
USE_PASS
(
lock_free_optimize_pass
);
USE_PASS
(
graph_to_program_pass
);
paddle/fluid/framework/ir/pass.cc
浏览文件 @
451896fc
...
...
@@ -28,10 +28,14 @@ std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
PADDLE_ENFORCE
(
graph
->
Has
(
attr
),
"Required graph atrribute %s not set."
,
attr
);
}
auto
*
native_graph
=
graph
.
get
();
auto
applied_graph
=
ApplyImpl
(
std
::
move
(
graph
));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE
(
!
HasCircle
(
*
applied_graph
),
"Illegal Pass. Generated graph shouldn't has cycle."
);
PADDLE_ENFORCE
(
applied_graph
.
get
()
==
native_graph
,
"Pass::Apply() cannot delete the passed graph and shouldn't "
"return a new graph.(For the need of pybind11)"
);
applied_
=
true
;
return
applied_graph
;
}
...
...
paddle/fluid/pybind/ir.cc
浏览文件 @
451896fc
...
...
@@ -42,6 +42,7 @@ void BindGraph(py::module *m) {
.
def
(
"get_float"
,
&
Graph
::
Get
<
float
>
)
.
def
(
"get_double"
,
&
Graph
::
Get
<
double
>
)
.
def
(
"get_string"
,
&
Graph
::
Get
<
std
::
string
>
)
.
def
(
"get_program"
,
&
Graph
::
Get
<
ProgramDesc
>
)
.
def
(
"set"
,
[](
Graph
&
self
,
const
std
::
string
&
attr_name
,
int
attr
)
{
return
self
.
Set
(
attr_name
,
new
int
(
attr
));
})
.
def
(
"set"
,
...
...
@@ -57,6 +58,11 @@ void BindGraph(py::module *m) {
[](
Graph
&
self
,
const
std
::
string
&
attr_name
,
double
attr
)
{
return
self
.
Set
(
attr_name
,
new
double
(
attr
));
})
.
def
(
"set"
,
[](
Graph
&
self
,
const
std
::
string
&
attr_name
,
const
ProgramDesc
&
attr
)
{
return
self
.
Set
(
attr_name
,
new
ProgramDesc
(
attr
));
})
.
def
(
"erase"
,
&
Graph
::
Erase
)
.
def
(
"nodes"
,
&
Graph
::
Nodes
,
return_value_policy
::
reference
)
.
def
(
"create_var_node"
,
...
...
paddle/fluid/pybind/protobuf.cc
浏览文件 @
451896fc
...
...
@@ -229,6 +229,12 @@ void BindBlockDesc(pybind11::module *m) {
void
BindVarDsec
(
pybind11
::
module
*
m
)
{
pybind11
::
class_
<
pd
::
VarDesc
>
var_desc
(
*
m
,
"VarDesc"
,
""
);
var_desc
.
def
(
"__init__"
,
[](
pd
::
VarDesc
&
self
,
const
pybind11
::
bytes
&
binary_str
)
{
std
::
string
str
(
binary_str
);
new
(
&
self
)
pd
::
VarDesc
(
str
);
},
pybind11
::
return_value_policy
::
reference
)
.
def
(
"name"
,
&
pd
::
VarDesc
::
Name
,
pybind11
::
return_value_policy
::
reference
)
.
def
(
"set_name"
,
&
pd
::
VarDesc
::
SetName
)
.
def
(
"set_shape"
,
&
pd
::
VarDesc
::
SetShape
)
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
451896fc
...
...
@@ -786,9 +786,20 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"disable_profiler"
,
platform
::
DisableProfiler
);
m
.
def
(
"is_profiler_enabled"
,
platform
::
IsProfileEnabled
);
m
.
def
(
"reset_profiler"
,
platform
::
ResetProfiler
);
m
.
def
(
"get_pass"
,
[](
const
py
::
bytes
&
binary_str
)
{
std
::
string
pass_type
(
binary_str
);
auto
pass
=
framework
::
ir
::
PassRegistry
::
Instance
().
Get
(
pass_type
);
return
std
::
shared_ptr
<
framework
::
ir
::
Pass
>
(
std
::
move
(
pass
));
});
py
::
class_
<
ir
::
Pass
,
std
::
shared_ptr
<
ir
::
Pass
>>
pass
(
m
,
"Pass"
);
pass
.
def
(
py
::
init
())
.
def
(
"has"
,
&
ir
::
Pass
::
Has
)
.
def
(
"set_program"
,
[](
ir
::
Pass
&
self
,
const
std
::
string
&
attr_name
,
const
ProgramDesc
&
attr
)
{
return
self
.
Set
(
attr_name
,
new
ProgramDesc
(
attr
));
})
.
def
(
"set_str"
,
[](
ir
::
Pass
&
self
,
const
std
::
string
&
name
,
const
std
::
string
&
attr
)
{
...
...
@@ -796,11 +807,12 @@ All parameter, weight, gradient are variables in Paddle.
})
.
def
(
"set_int"
,
[](
ir
::
Pass
&
self
,
const
std
::
string
&
name
,
int
val
)
{
self
.
Set
<
const
int
>
(
name
,
new
int
(
val
));
})
.
def
(
"get_program"
,
&
ir
::
Pass
::
Get
<
ProgramDesc
>
)
.
def
(
"type"
,
&
ir
::
Pass
::
Type
)
.
def
(
"apply"
,
[](
ir
::
Pass
&
self
,
std
::
shared_ptr
<
ir
::
Graph
>
graph
)
{
std
::
unique_ptr
<
ir
::
Graph
>
origin_graph
(
graph
.
get
());
auto
optim_graph
=
self
.
Apply
(
std
::
move
(
origin_graph
));
graph
.
reset
(
optim_graph
.
release
()
);
optim_graph
.
release
(
);
});
py
::
class_
<
ir
::
PassBuilder
,
std
::
shared_ptr
<
ir
::
PassBuilder
>>
pb
(
...
...
python/paddle/fluid/contrib/slim/graph/graph.py
浏览文件 @
451896fc
...
...
@@ -13,8 +13,81 @@
# limitations under the License.
from
....framework
import
Program
from
....framework
import
Block
from
....
import
core
__all__
=
[
'Graph'
,
'ImitationGraph'
,
'IRGraph'
]
__all__
=
[
'Graph'
,
'ImitationGraph'
,
'PyGraph'
]
class
PyGraph
(
object
):
"""
PyGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
def
__init__
(
self
,
graph
):
assert
isinstance
(
graph
,
core
.
Graph
),
'graph must be the instance of core.Graph.'
self
.
graph
=
graph
def
all_parameters
(
self
):
params
=
[]
for
node
in
self
.
graph
.
nodes
():
if
node
.
is_var
()
and
node
.
var
().
persistable
():
params
.
append
(
node
)
return
params
def
all_vars
(
self
):
return
[
node
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_var
()]
def
all_ops
(
self
):
return
[
node
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_op
()]
def
create_param_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
var_desc
=
core
.
VarDesc
(
name
)
var_desc
.
set_type
(
var_type
)
var_desc
.
set_shape
(
shape
)
var_desc
.
set_dtype
(
var_dtype
)
var_desc
.
set_persistable
(
True
)
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_var_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
var_desc
=
core
.
VarDesc
(
name
)
var_desc
.
set_type
(
var_type
)
var_desc
.
set_shape
(
shape
)
var_desc
.
set_dtype
(
var_dtype
)
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_var_node_from_desc
(
self
,
var_desc
):
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_op_node
(
self
,
op_type
,
attrs
,
inputs
,
outputs
):
op_desc
=
core
.
OpDesc
()
op_desc
.
set_type
(
op_type
)
for
attr
,
value
in
attrs
.
iteritems
():
self
.
_update_desc_attr
(
op_desc
,
attr
,
value
)
for
input_name
,
var_node
in
inputs
.
iteritems
():
op_desc
.
set_input
(
input_name
,
[
var_node
.
name
()])
for
output_name
,
var_node
in
outputs
.
iteritems
():
op_desc
.
set_output
(
output_name
,
[
var_node
.
name
()])
return
self
.
graph
.
create_op_node
(
op_desc
)
def
create_op_node_from_desc
(
self
,
op_desc
):
return
self
.
graph
.
create_op_node
(
op_desc
)
def
_update_desc_attr
(
self
,
desc
,
name
,
val
):
"""
Update the value of desc's attribute by attribute's name.
"""
if
isinstance
(
val
,
Block
):
desc
.
set_block_attr
(
name
,
val
.
desc
)
elif
isinstance
(
val
,
list
)
and
val
and
all
(
isinstance
(
v
,
Block
)
for
v
in
val
):
desc
.
set_blocks_attr
(
name
,
[
v
.
desc
for
v
in
val
])
elif
isinstance
(
val
,
core
.
BlockDesc
)
or
\
isinstance
(
val
,
core
.
ProgramDesc
):
desc
.
set_serialized_attr
(
name
,
val
.
serialize_to_string
())
else
:
desc
.
_set_attr
(
name
,
val
)
class
Graph
(
object
):
...
...
@@ -39,7 +112,3 @@ class ImitationGraph(Graph):
def
all_parameters
(
self
):
return
self
.
program
.
global_block
().
all_parameters
()
class
IRGraph
(
Graph
):
pass
python/paddle/fluid/contrib/slim/quantization/quantization_performer.py
0 → 100644
浏览文件 @
451896fc
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
numpy
as
np
from
....
import
core
from
....initializer
import
Constant
from
....
import
unique_name
from
..graph
import
PyGraph
class
QuantizationPerformer
(
object
):
def
__init__
(
self
,
weight_bits
=
8
,
activation_bits
=
8
,
activation_quantize_type
=
'abs_max'
,
weight_quantize_type
=
'abs_max'
,
window_size
=
10000
):
"""
Convert and rewrite the IRGraph according to weight and
activation quantization type.
Args:
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation,
now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode,
the quantization scale will be calculated dynamically each step
in both training and testing period. If use 'range_abs_max',
a static quantization scale will be calculated during training
and used in inference.
weight_quantize_type (str): quantization type for weights,
support 'abs_max'. The 'range_abs_max' usually is not used for
weight, since weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
.. code-block:: python
# the original graph will be rewrite, if you don't want to
# change it, please clone at first.
# graph = graph.clone()
from paddle.fluid.contrib.slim import *
from paddle.fluid.contrib.quantize import *
graph = IRGraph(program)
performer = QuantizationPerformer()
performer.quantize_transform(graph)
"""
self
.
weight_bits
=
weight_bits
self
.
activation_bits
=
activation_bits
quant_type
=
[
'abs_max'
,
'range_abs_max'
]
if
activation_quantize_type
not
in
quant_type
:
raise
ValueError
(
"Unknown activation_quantize_type : '%s'. It can only be "
,
"'abs_max' or 'range_abs_max'."
,
str
(
activation_quantize_type
))
if
weight_quantize_type
not
in
quant_type
:
raise
ValueError
(
"Unknown weight_quantize_type: '%s'. It can only be "
,
"'abs_max' or 'range_abs_max'."
,
str
(
weight_quantize_type
))
self
.
activation_quantize_type
=
activation_quantize_type
self
.
weight_quantize_type
=
weight_quantize_type
self
.
window_size
=
window_size
self
.
need_inited_outer
=
collections
.
OrderedDict
()
self
.
quantizable_ops
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
self
.
quantizable_grad_ops
=
[
'%s_grad'
%
(
op
)
for
op
in
self
.
quantizable_ops
]
self
.
fake_quant_op_types
=
[
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
]
self
.
fake_dequant_op_types
=
[
'fake_dequantize_max_abs'
]
self
.
is_test
=
None
self
.
global_step
=
None
def
quantize_transform
(
self
,
graph
,
is_test
):
self
.
need_inited_outer
.
clear
()
self
.
is_test
=
is_test
assert
isinstance
(
graph
,
PyGraph
),
'graph must be the instance of PyGraph.'
# marked the variable which has been dequantized.
dequantized_vars
=
collections
.
OrderedDict
()
params
=
[
p
.
name
()
for
p
in
graph
.
all_parameters
()]
def
_transform_forward
(
graph
,
op
):
for
var_node
in
op
.
inputs
:
if
var_node
.
name
()
in
dequantized_vars
:
dequant_var_node
=
dequantized_vars
[
var_node
.
name
()]
else
:
quant_bits
=
self
.
weight_bits
if
var_node
.
name
()
in
params
\
else
self
.
activation_bits
quant_type
=
self
.
weight_quantize_type
if
var_node
.
name
()
\
in
params
else
self
.
activation_quantize_type
quant_var_node
,
scale_var_node
=
self
.
_insert_quant_op
(
graph
,
var_node
,
quant_bits
,
quant_type
)
dequant_var_node
=
self
.
_insert_dequant_op
(
graph
,
quant_var_node
,
scale_var_node
,
quant_bits
)
dequantized_vars
[
var_node
.
name
()]
=
dequant_var_node
self
.
_update_input
(
var_node
,
dequant_var_node
,
op
)
if
not
self
.
is_test
:
self
.
_create_global_step
(
graph
)
ops
=
graph
.
all_ops
()
for
op
in
ops
:
# transform the forward graph
if
op
.
name
()
in
self
.
quantizable_ops
:
_transform_forward
(
graph
,
op
)
# rename the inputs of backward op
if
op
.
name
()
in
self
.
quantizable_grad_ops
:
_transform_backward
(
graph
,
op
)
return
self
.
need_inited_outer
def
_insert_quant_op
(
self
,
graph
,
var_node
,
quant_bits
,
quant_type
):
"""
Insert fake_quantize_op in the graph.
"""
if
quant_type
==
'abs_max'
:
return
self
.
_insert_quant_abs_max_op
(
graph
,
var_node
,
quant_bits
)
elif
quant_type
==
'range_abs_max'
:
return
self
.
_inser_quant_range_abs_max_op
(
graph
,
var_node
,
quant_bits
)
def
_insert_quant_abs_max_op
(
self
,
graph
,
var_node
,
quant_bits
):
"""
Insert fake_quantize_abs_max op in the graph.
"""
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
quant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
shape
=
var_node
.
var
().
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
scale_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
shape
=
var_node
.
var
().
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_abs_max'
,
attrs
=
{
'bit_length'
:
quant_bits
},
inputs
=
{
'X'
:
var_node
},
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_var_node
})
self
.
_link_to
(
var_node
,
quant_op_node
)
self
.
_link_to
(
quant_op_node
,
quant_var_node
)
self
.
_link_to
(
quant_op_node
,
scale_var_node
)
return
quant_var_node
,
scale_var_node
def
_insert_quant_range_abs_max_op
(
self
,
graph
,
var_node
,
quant_bits
):
"""
Insert fake_quantize_range_abs_max on the graph.
"""
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
quant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_quantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
shape
=
var_node
.
var
().
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
scale_in_node
=
graph
.
create_param_node
(
name
=
self
.
_quantized_scale_name
(
var_node
.
name
()),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
1
],
var_dtype
=
var_node
.
var
().
dtype
())
self
.
need_inited_outer
[
scale_in_node
.
var
()]
=
Constant
(
value
=
0.001
)
scale_out_node
=
graph
.
create_var_node_from_desc
(
scale_in_node
.
var
())
inputs
=
{
'X'
:
var_node
,
'InScale'
:
scale_in_node
}
outputs
=
{
'Out'
:
quant_var_node
,
'OutScale'
:
scale_out_node
}
if
not
self
.
is_test
:
# The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
scales_node
=
graph
.
create_param_node
(
name
=
unique_name
.
generate
(
'scales'
),
var_type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
shape
=
[
self
.
window_size
],
var_dtype
=
var_node
.
var
().
dtype
())
self
.
need_inited_outer
[
scales_node
.
var
()]
=
Constant
(
value
=
0
)
inputs
[
'Iter'
]
=
self
.
global_step
outputs
[
'OutScales'
]
=
scales_node
attrs
=
{
'window_size'
:
self
.
window_size
,
'bit_length'
:
quant_bits
,
'is_test'
:
self
.
is_test
}
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_range_abs_max'
,
attrs
=
attrs
,
inputs
=
inputs
,
outputs
=
outputs
)
self
.
_link_to
(
var_node
,
quant_op_node
)
self
.
_link_to
(
scale_in_node
,
quant_op_node
)
self
.
_link_to
(
quant_op_node
,
quant_var_node
)
self
.
_link_to
(
quant_op_node
,
scale_out_node
)
if
not
self
.
is_test
:
self
.
_link_to
(
self
.
global_step
,
quant_op_node
)
self
.
_link_to
(
quant_op_node
,
scales_node
)
return
quant_var_node
,
scale_out_node
def
_insert_dequant_op
(
self
,
graph
,
var_node
,
scale_var_node
,
quant_bits
):
"""
Insert fake_dequantize_op in the graph.
"""
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
dequant_var_node
=
graph
.
create_var_node
(
name
=
self
.
_dequantized_var_name
(
var_node
.
name
()),
var_type
=
var_node
.
var
().
type
(),
shape
=
var_node
.
var
().
shape
(),
var_dtype
=
var_node
.
var
().
dtype
())
max_range
=
(
1
<<
(
quant_bits
-
1
))
-
1
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_dequantize_max_abs'
,
attrs
=
{
'max_range'
:
float
(
max_range
)},
inputs
=
{
'X'
:
var_node
,
'Scale'
:
scale_var_node
},
outputs
=
{
'Out'
:
dequant_var_node
})
self
.
_link_to
(
var_node
,
dequant_op_node
)
self
.
_link_to
(
scale_var_node
,
dequant_op_node
)
self
.
_link_to
(
dequant_op_node
,
dequant_var_node
)
return
dequant_var_node
def
_update_input
(
self
,
old_input_node
,
new_input_node
,
op_node
):
old_input_node
.
outputs
.
remove
(
op_node
)
op_node
.
inputs
.
remove
(
old_input_node
)
new_input_node
.
outputs
.
append
(
op_node
)
op_node
.
inputs
.
append
(
new_input_node
)
def
_link_to
(
node_in
,
node_out
):
node_in
.
outputs
.
append
(
node_out
)
node_out
.
inputs
.
append
(
node_in
)
def
_quantized_var_name
(
self
,
var_name
):
"""
Return quantized variable name for the input `var_name`.
"""
return
"%s.quantized"
%
(
var_name
)
def
_dequantized_var_name
(
self
,
var_name
):
"""
Return dequantized variable name for the input `var_name`.
"""
return
"%s.dequantized"
%
(
var_name
)
def
_quantized_scale_name
(
self
,
var_name
):
"""
Return quantized variable name for the input `var_name`.
"""
return
"%s.scale"
%
(
var_name
)
def
_original_var_name
(
self
,
var_name
):
"""
Return the original variable name.
"""
if
var_name
.
endswith
(
'.quantized.dequantized'
):
return
var_name
[:
-
len
(
'.quantized.dequantized'
)]
if
var_name
.
endswith
(
'.quantized'
):
return
var_name
[:
-
len
(
'.quantized'
)]
if
var_name
.
endswith
(
'.dequantized'
):
return
var_name
[:
-
len
(
'.dequantized'
)]
if
var_name
.
endswith
(
'.scale'
):
return
var_name
[:
-
len
(
'.scale'
)]
else
:
return
var_name
def
_is_float
(
self
,
v
):
return
isinstance
(
v
,
float
)
or
isinstance
(
v
,
np
.
float32
)
def
_quant
(
self
,
x
,
scale
,
num_bits
):
y
=
np
.
round
(
x
/
scale
*
((
1
<<
(
num_bits
-
1
))
-
1
))
return
y
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录