Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
735864a0
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
735864a0
编写于
1月 06, 2020
作者:
M
MyPandaShaoxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: add dynamic quant fuse pass
上级
0a075279
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
295 addition
and
6 deletion
+295
-6
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+15
-1
lite/core/mir/fusion/quant_dequant_op_fuser.cc
lite/core/mir/fusion/quant_dequant_op_fuser.cc
+243
-3
lite/core/mir/fusion/quant_dequant_op_fuser.h
lite/core/mir/fusion/quant_dequant_op_fuser.h
+31
-0
lite/operators/fake_quantize_range_abs_max.cc
lite/operators/fake_quantize_range_abs_max.cc
+2
-0
lite/operators/fake_quantize_range_abs_max.h
lite/operators/fake_quantize_range_abs_max.h
+4
-2
未找到文件。
lite/core/mir/fusion/quant_dequant_fuse_pass.cc
浏览文件 @
735864a0
...
@@ -27,10 +27,24 @@ namespace mir {
...
@@ -27,10 +27,24 @@ namespace mir {
void
QuantDequantFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
void
QuantDequantFusePass
::
Apply
(
const
std
::
unique_ptr
<
SSAGraph
>&
graph
)
{
// delete quant node
// delete quant node
std
::
vector
<
std
::
string
>
quant_op_types
=
{
std
::
vector
<
std
::
string
>
quant_op_types
=
{
"fake_quantize_range_abs_max"
,
"fake_quantize_moving_average_abs_max"
};
"fake_quantize_abs_max"
,
"fake_quantize_range_abs_max"
,
"fake_quantize_moving_average_abs_max"
};
/*
for (auto& op_type : {"conv2d", "mul", "depthwise_conv2d"}) {
for (int i = 5; i >= 1; --i){
fusion::DynamicQuantDequantOpFuser fuser("fake_quantize_abs_max", op_type,
i);
fuser(graph.get());
}
}
*/
for
(
auto
&
op_type
:
quant_op_types
)
{
for
(
auto
&
op_type
:
quant_op_types
)
{
fusion
::
DeleteQuantOpFuser
fuser
(
op_type
);
fusion
::
DeleteQuantOpFuser
fuser
(
op_type
);
fuser
(
graph
.
get
());
fuser
(
graph
.
get
());
fusion
::
DeleteDynamicQuantOpFuser
dfuser
(
op_type
);
dfuser
(
graph
.
get
());
}
}
// fuse quantized node and dequant node
// fuse quantized node and dequant node
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.cc
浏览文件 @
735864a0
...
@@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
...
@@ -77,6 +77,55 @@ cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
return
op_desc
;
return
op_desc
;
}
}
// Matches a dynamic-range fake-quant op together with its input activation,
// its output activation and its output scale var, so the whole pattern can
// be deleted by InsertNewNode.
void DeleteDynamicQuantOpFuser::BuildPattern() {
  // Pattern: input_act -> quant_op -> {Out, OutScale}
  auto* input_act_node =
      VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
  auto* quant_node =
      OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
  auto* output_scale_node =
      VarNode("output_scale_node")
          ->assert_is_op_output(quant_op_type_, "OutScale");
  auto* output_act_node =
      VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");

  quant_node->LinksFrom({input_act_node});
  output_scale_node->LinksFrom({quant_node});
  output_act_node->LinksFrom({quant_node});
  // Fixed: the message previously said "DeleteQuantOpFuser", which is a
  // different fuser in this file and made debug traces misleading.
  VLOG(4) << "DeleteDynamicQuantOpFuser BuildPattern quant_op_type:"
          << quant_op_type_;
}
// Deletes the matched dynamic fake-quant pattern: every consumer of the quant
// op's output is given the "bit_length" attribute and re-linked directly to
// the quant op's input activation; the quant op, its output activation and
// its output scale var are then removed from the graph.
void DeleteDynamicQuantOpFuser::InsertNewNode(SSAGraph* graph,
                                              const key2nodes_t& matched) {
  auto* input_act_node = matched.at("input_act_node");
  auto* quant_node = matched.at("quant_node");
  auto* output_scale_node = matched.at("output_scale_node");
  auto* output_act_node = matched.at("output_act_node");

  // Propagate the quantization bit width to every consumer, then bypass the
  // quant op.  Unlike DeleteQuantOpFuser, no "input_scale" attribute is set
  // here (dynamic quantization computes the scale at runtime), so the
  // previously computed-but-unused `scale_value` local and the tensor lookup
  // that fed it have been removed (dead code / unused-variable warning).
  int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");

  auto outlinks = output_act_node->outlinks;
  for (auto* quantized_node : outlinks) {
    auto* op_desc = quantized_node->stmt()->mutable_op_info();
    op_desc->SetAttr<int>("bit_length", bit_length);
    // NOTE(review): the consumer's op desc still names output_act_node's
    // variable in its input argument list; confirm a later pass (or the op's
    // Attach) remaps it to input_act_node's name.
    IR_NODE_LINK_TO(input_act_node, quantized_node)
  }

  // delete nodes and edges
  std::unordered_set<const Node*> nodes2rm = {
      quant_node, output_scale_node, output_act_node};
  GraphSafeRemoveNodes(graph, nodes2rm);
}
// The matched pattern is erased rather than replaced, so there is no fused
// op to describe; an empty desc satisfies the FuseBase contract.
cpp::OpDesc DeleteDynamicQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
  return cpp::OpDesc{};
}
void
DequantOpFuser
::
BuildPattern
()
{
void
DequantOpFuser
::
BuildPattern
()
{
std
::
string
weight_name
=
""
;
std
::
string
weight_name
=
""
;
if
(
quantized_op_type_
==
"conv2d"
||
if
(
quantized_op_type_
==
"conv2d"
||
...
@@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -130,8 +179,11 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
auto
&
valid_places
=
quantized_op
->
stmt
()
->
op
()
->
valid_places
();
int
bit_length
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
bit_length
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
int
>
(
"bit_length"
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
float
input_scale
=
float
input_scale
=
0
;
if
(
quantized_op
->
stmt
()
->
op_info
()
->
HasAttr
(
"input_scale"
))
{
input_scale
=
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
quantized_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"input_scale"
);
}
float
max_range
=
dequant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
max_range
=
dequant_op
->
stmt
()
->
op_info
()
->
GetAttr
<
float
>
(
"max_range"
);
float
whole_weight_scale
=
float
whole_weight_scale
=
static_cast
<
float
>
(
range
*
range
)
/
max_range
/
range
;
static_cast
<
float
>
(
range
*
range
)
/
max_range
/
range
;
...
@@ -163,7 +215,9 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
...
@@ -163,7 +215,9 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale
.
push_back
(
whole_weight_scale
);
weight_scale
.
push_back
(
whole_weight_scale
);
}
}
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
op_desc
.
SetAttr
(
"enable_int8"
,
true
);
if
(
quantized_op
->
stmt
()
->
op_info
()
->
HasAttr
(
"input_scale"
))
{
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
op_desc
.
SetAttr
(
"input_scale"
,
input_scale
);
}
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
op_desc
.
SetAttr
(
"weight_scale"
,
weight_scale
);
// change the weight from the float type to int8 type.
// change the weight from the float type to int8 type.
...
@@ -464,6 +518,192 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
...
@@ -464,6 +518,192 @@ cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp
::
OpDesc
op_desc
;
cpp
::
OpDesc
op_desc
;
return
op_desc
;
return
op_desc
;
}
}
// ================dynamic quant fuse==============
// #define DYNAMIC_RANGE
// Matches one fake-quant op feeding `times_` parallel chains of
// (quantized compute op -> fake_dequantize_max_abs).  For each repetition i
// the five pattern nodes are stored consecutively in `nodes` at the offsets
// declared below.
void DynamicQuantDequantOpFuser::BuildPattern() {
  const int kNumFields = 5;
  const int kQuantizedWeightOffset = 0;
  const int kQuantizedOpOffset = 1;
  const int kQuantizedOpOutOffset = 2;
  const int kDequantOpOffset = 3;
  const int kDequantOpOutOffset = 4;

  // conv-like ops take their weight via "Filter"; every other supported op
  // (mul) takes it via "Y".
  std::string weight_name = "";
  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
    weight_name = "Filter";
  } else {
    weight_name = "Y";
  }

  auto* quant_op_input = VarNode("quant_op_input")
                             ->assert_is_op_input(quant_type_, "X")
                             ->AsInput();
#ifdef DYNAMIC_RANGE
  auto* quant_op_in_scale = VarNode("quant_op_in_scale")
                                ->assert_is_op_input(quant_type_, "InScale")
                                ->AsIntermediate();
#endif
  auto* quant_op = OpNode("quant_op", quant_type_)
                       ->assert_is_op(quant_type_)
                       ->AsIntermediate();
  auto* quant_op_out_scale =
      VarNode("quant_op_out_scale")
          ->assert_is_op_output(quant_type_, "OutScale")
          ->assert_is_op_input("fake_dequantize_max_abs", "Scale")
          ->AsIntermediate();
  auto* quant_op_out = VarNode("quant_op_out")
                           ->assert_is_op_output(quant_type_, "Out")
                           ->assert_is_op_input(op_type_)
                           ->AsIntermediate();

  std::vector<PMNode*> nodes;
  for (int i = 0; i < times_; i++) {
    nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
                        ->assert_is_op_input(op_type_, weight_name)
                        ->AsInput());
    nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
                        ->assert_is_op(op_type_)
                        ->AsIntermediate());
    nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
                        ->assert_is_op_output(op_type_)
                        ->assert_is_op_input("fake_dequantize_max_abs", "X")
                        ->AsIntermediate());
    nodes.push_back(
        OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
            ->assert_is_op("fake_dequantize_max_abs")
            ->AsIntermediate());
    nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
                        ->assert_is_op_output("fake_dequantize_max_abs", "Out")
                        ->AsOutput());
  }

  // Fixed: with DYNAMIC_RANGE defined the original code linked quant_op
  // twice — once including quant_op_in_scale and then unconditionally
  // without it — so the plain link now lives in the #else branch.  The
  // enabled (non-DYNAMIC_RANGE) build behaves exactly as before.
#ifdef DYNAMIC_RANGE
  quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
#else
  quant_op->LinksFrom({quant_op_input});
#endif
  quant_op_out->LinksFrom({quant_op});
  quant_op_out_scale->LinksFrom({quant_op});
  for (int i = 0; i < times_; i++) {
    nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
        {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
    nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
        {nodes[i * kNumFields + kQuantizedOpOffset]});
    nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
        {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
    nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
        {nodes[i * kNumFields + kDequantOpOffset]});
  }
}
// Replaces each matched (quant -> quantized op -> dequant) chain with a
// single quantized op: the op desc gains a per-output-channel "weight_scale"
// attribute and the weight tensor is converted in place (int8 normally,
// scaled float under LITE_WITH_FPGA).
void DynamicQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                               const key2nodes_t& matched) {
  const int kNumFields = 5;
  const int kQuantizedWeightOffset = 0;
  const int kQuantizedOpOffset = 1;
  const int kDequantOpOffset = 3;
  const int kDequantOpOutOffset = 4;

  auto* quant_op_input = matched.at("quant_op_input");
#ifdef DYNAMIC_RANGE
  auto* quant_op_in_scale = matched.at("quant_op_in_scale");
#endif
  auto* quant_op = matched.at("quant_op");

  // Same 5-slot-per-repetition layout as BuildPattern.
  std::vector<Node*> nodes;
  for (int i = 0; i < times_; i++) {
    nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
    nodes.push_back(matched.at(string_format("quantized_op%d", i)));
    nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
    nodes.push_back(matched.at(string_format("dequant_op%d", i)));
    nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
  }
  int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
  auto* scope = quant_op->stmt()->op()->scope();
  auto& valid_places = quant_op->stmt()->op()->valid_places();
  int range = ((1 << (bit_length - 1)) - 1);
#ifdef DYNAMIC_RANGE
  auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
                           ->GetMutable<lite::Tensor>();
  float input_scale = input_scale_t->data<float>()[0] / range;
  VLOG(4) << "range: " << range << " input_scale: " << input_scale;
#endif

  for (int i = 0; i < times_; i++) {
    float max_range = nodes[i * kNumFields + kDequantOpOffset]
                          ->stmt()
                          ->op_info()
                          ->GetAttr<float>("max_range");
    // weight_scale = max(abs(weight))
    float whole_weight_scale =
        static_cast<float>(range * range) / max_range / range;

    cpp::OpDesc op_desc =
        *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
    auto quantized_weight_var_name =
        nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
    auto quantized_weight_t =
        scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();

    std::vector<float> weight_scale;
    // Fixed: was left uninitialized; an op_type_ matching neither branch
    // below would have read an indeterminate value in the fill loop.
    int weight_scale_size = 0;
    if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
      op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
      op_desc.SetOutput(
          "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
      // Conv weight shape: Cout * Cin * kh * hw, the weight_scale_size should
      // be Cout.
      weight_scale_size = quantized_weight_t->dims()[0];
    } else if (op_type_ == "mul") {
      op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
      op_desc.SetOutput(
          "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
      // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
      weight_scale_size = quantized_weight_t->dims()[1];
    }
    // Fixed: the fill loop reused `i`, shadowing the outer repetition
    // counter; renamed to `j` (same for the `k` loops below).
    for (int j = 0; j < weight_scale_size; j++) {
      weight_scale.push_back(whole_weight_scale);
    }
    // op_desc.SetAttr("enable_int8", true);
    // op_desc.SetAttr("input_scale", input_scale);
    op_desc.SetAttr("weight_scale", weight_scale);

    // Convert the weight tensor in place, using a float copy as the source.
    Tensor temp_tensor;
    temp_tensor.CopyDataFrom(*quantized_weight_t);
    float* temp_data = temp_tensor.mutable_data<float>();
    size_t weight_num = quantized_weight_t->data_size();
    quantized_weight_t->set_persistable(true);
#ifdef LITE_WITH_FPGA
    // FPGA keeps float weights, pre-multiplied by the weight scale.
    float* quantized_weight_data = quantized_weight_t->mutable_data<float>();
    for (size_t k = 0; k < weight_num; k++) {
      quantized_weight_data[k] = temp_data[k] * whole_weight_scale;
    }
    quantized_weight_t->set_precision(PRECISION(kFloat));
#else
    int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
    for (size_t k = 0; k < weight_num; k++) {
      quantized_weight_data[k] = static_cast<int8_t>(temp_data[k]);
    }
    quantized_weight_t->set_precision(PRECISION(kInt8));
#endif

    auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
    quantized_op->Attach(op_desc, scope);
    auto* new_op_node =
        graph->GraphCreateInstructNode(quantized_op, valid_places);
    IR_NODE_LINK_TO(quant_op_input, new_op_node);
    IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
                    new_op_node);
    IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
  }
}
// Node rewriting happens entirely inside InsertNewNode, so no fused op desc
// is needed here; return a default-constructed one.
cpp::OpDesc DynamicQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
  return cpp::OpDesc{};
}
}
// namespace fusion
}
// namespace fusion
}
// namespace mir
}
// namespace mir
...
...
lite/core/mir/fusion/quant_dequant_op_fuser.h
浏览文件 @
735864a0
...
@@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase {
...
@@ -52,6 +52,19 @@ class DeleteQuantOpFuser : public FuseBase {
private:
private:
std
::
string
quant_op_type_
{};
std
::
string
quant_op_type_
{};
};
};
/* DeleteDynamicQuantOpFuser removes a dynamic-range fake-quant op of type
 * `quant_op_type` from the graph: consumers of its output are re-linked to
 * its input activation and receive the op's "bit_length" attribute, then the
 * quant op together with its output activation and output scale var is
 * deleted (see the .cc implementation).
 */
class DeleteDynamicQuantOpFuser : public FuseBase {
 public:
  explicit DeleteDynamicQuantOpFuser(const std::string& quant_op_type)
      : quant_op_type_(quant_op_type) {}

  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  // Returns an empty desc: the pattern is deleted, not replaced.
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

 private:
  std::string quant_op_type_{};  // fake-quant op type this fuser deletes
};
/* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
/* DequantOpFuser process conv2d/depthwise_conv2d/mul + fake_dequantize_max_abs.
*/
*/
...
@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase {
...
@@ -106,6 +119,24 @@ class DeleteQuantDequantOpFuser : public FuseBase {
private:
private:
std
::
string
quantized_op_type_
{};
std
::
string
quantized_op_type_
{};
};
};
// dynamic quant-dequant op fuser
/* DynamicQuantDequantOpFuser fuses one fake-quant op feeding `times_`
 * parallel (quantized compute op -> fake_dequantize_max_abs) chains into
 * int8-weight ops (see the .cc implementation).
 */
class DynamicQuantDequantOpFuser : public FuseBase {
 public:
  // NOTE(review): the first argument is the fake-quant op type (it is stored
  // in quant_type_) and the second is the quantized compute op type (stored
  // in op_type_) — the parameter names read the other way around; confirm
  // against callers before renaming.
  explicit DynamicQuantDequantOpFuser(const std::string& quantized_op_type,
                                      const std::string& op_type,
                                      int i)
      : op_type_(op_type), quant_type_(quantized_op_type), times_(i) {}
  void BuildPattern() override;
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;

 private:
  // Returns an empty desc: node rewriting happens in InsertNewNode.
  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

 private:
  std::string op_type_{};     // quantized compute op (conv2d/depthwise_conv2d/mul)
  std::string quant_type_{};  // fake-quant op type feeding the pattern
  int times_{1};              // number of compute-op branches sharing one quant op
};
}
// namespace fusion
}
// namespace fusion
}
// namespace mir
}
// namespace mir
...
...
lite/operators/fake_quantize_range_abs_max.cc
浏览文件 @
735864a0
...
@@ -23,3 +23,5 @@ namespace operators {} // namespace operators
...
@@ -23,3 +23,5 @@ namespace operators {} // namespace operators
REGISTER_LITE_OP
(
fake_quantize_range_abs_max
,
REGISTER_LITE_OP
(
fake_quantize_range_abs_max
,
paddle
::
lite
::
operators
::
FakeQuantizeRangeMaxAbsOpLite
);
paddle
::
lite
::
operators
::
FakeQuantizeRangeMaxAbsOpLite
);
REGISTER_LITE_OP
(
fake_quantize_abs_max
,
paddle
::
lite
::
operators
::
FakeQuantizeRangeMaxAbsOpLite
);
lite/operators/fake_quantize_range_abs_max.h
浏览文件 @
735864a0
...
@@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
...
@@ -40,13 +40,15 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
if
(
op_desc
.
HasInput
(
"InScale"
))
{
auto
in_scale
=
op_desc
.
Input
(
"InScale"
).
front
();
auto
in_scale
=
op_desc
.
Input
(
"InScale"
).
front
();
param_
.
in_scale
=
scope
->
FindVar
(
in_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
}
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
auto
out_scale
=
op_desc
.
Output
(
"OutScale"
).
front
();
auto
out_scale
=
op_desc
.
Output
(
"OutScale"
).
front
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
x
=
scope
->
FindVar
(
x
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
in_scale
=
scope
->
FindVar
(
in_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out
=
scope
->
FindVar
(
out
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out_scale
=
scope
->
FindVar
(
out_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
out_scale
=
scope
->
FindVar
(
out_scale
)
->
GetMutable
<
lite
::
Tensor
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录