PaddlePaddle / Paddle-Lite
Commit 08a3ed12 (unverified), authored by hong19860320 on Mar 10, 2020, committed via GitHub on Mar 10, 2020
Parent commit: 89da9953

[CORE] Support the fully quantized model for MTK and RK NPU (#3096)

21 changed files, +537 additions, -171 deletions
Changed files (21):

lite/api/paddle_use_passes.h                              +2    -1
lite/core/mir/CMakeLists.txt                              +1    -0
lite/core/mir/graph_visualize_pass.cc                     +80   -37
lite/core/mir/quantized_op_attributes_inference_pass.cc   +75   -0   (new)
lite/core/mir/quantized_op_attributes_inference_pass.h    +36   -0   (new)
lite/core/mir/ssa_graph.cc                                +12   -3
lite/core/mir/subgraph/subgraph_detector.cc               +31   -0
lite/core/mir/type_layout_cast_pass.cc                    +4    -3
lite/core/mir/type_layout_cast_pass.h                     +0    -12
lite/core/mir/type_precision_cast_pass.cc                 +118  -6
lite/core/mir/type_precision_cast_pass.h                  +1    -11
lite/core/mir/type_target_cast_pass.cc                    +3    -3
lite/core/mir/type_target_cast_pass.h                     +0    -12
lite/core/op_registry.cc                                  +5    -0
lite/core/optimizer.h                                     +6    -0
lite/core/tensor.cc                                       +2    -1
lite/kernels/arm/calib_compute.cc                         +107  -34
lite/kernels/arm/calib_compute.h                          +4    -2
lite/kernels/arm/layout_compute.cc                        +42   -32
lite/kernels/host/feed_compute.cc                         +4    -5
lite/kernels/host/fetch_compute.cc                        +4    -9
lite/api/paddle_use_passes.h (+2, -1)

The misspelled pass name "graph_visualze" is corrected to "graph_visualize_pass", and the new quantized-op pass is registered:

@@ -24,7 +24,7 @@ USE_MIR_PASS(generate_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
 USE_MIR_PASS(runtime_context_assign_pass);
-USE_MIR_PASS(graph_visualze);
+USE_MIR_PASS(graph_visualize_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(lite_fc_fuse_pass);

@@ -46,3 +46,4 @@ USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
 USE_MIR_PASS(weight_quantization_preprocess_pass);
+USE_MIR_PASS(quantized_op_attributes_inference_pass);
lite/core/mir/CMakeLists.txt (+1, -0)

@@ -36,6 +36,7 @@ lite_cc_library(mir_passes
     runtime_context_assign_pass.cc
     memory_optimize_pass.cc
     weight_quantization_preprocess_pass.cc
+    quantized_op_attributes_inference_pass.cc
     DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})

 # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
lite/core/mir/graph_visualize_pass.cc (+80, -37)

Visualize() is rewritten: ops are now walked in topological order, each op's attributes are printed alongside the graph, and the result is emitted through VLOG(5) instead of std::cout (the old code iterated graph->mutable_nodes(), keyed nodes ad hoc, and dumped the DOT text unconditionally). The registration name is also fixed from "graph_visualze" to "graph_visualize_pass". The resulting code:

#include <set>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"

namespace paddle {
namespace lite {
namespace mir {

using inference::analysis::Dot;

void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  VLOG(5) << "\n" << Visualize(graph.get());
}

std::string Visualize(mir::SSAGraph* graph) {
  std::ostringstream os;
  inference::analysis::Dot dot;
  auto string_trunc = [](const std::string& str) -> std::string {
    const int max_disp_size = 100;
    if (str.length() > max_disp_size)
      return str.substr(0, max_disp_size) + "...";
    return str;
  };
  auto attr_repr = [&](const OpInfo* op_info,
                       const std::string& attr_name) -> std::string {
    std::ostringstream os;
    using AttrType = cpp::OpDesc::AttrType;
    auto attr_type = op_info->GetAttrType(attr_name);
    switch (attr_type) {
      case AttrType::INT:
        os << ":int:" << std::to_string(op_info->GetAttr<int>(attr_name));
        break;
      case AttrType::FLOAT:
        os << ":float:" << std::to_string(op_info->GetAttr<float>(attr_name));
        break;
      case AttrType::BOOLEAN:
        os << ":int:" << std::to_string(op_info->GetAttr<bool>(attr_name));
        break;
      case AttrType::STRING:
        os << ":string:\""
           << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
        break;
      case AttrType::FLOATS: {
        auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
        os << ":floats: {" + Join(vals, ",") << "}";
      } break;
      case AttrType::INTS: {
        auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
        os << ":ints: {" + Join(vals, ",") + "}";
      } break;
      case AttrType::STRINGS: {
        auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
        os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
      } break;
      default:
        os << ":Unknown type(" << static_cast<int>(attr_type) << ")";
        break;
    }
    return os.str();
  };

  int op_idx = 0;
  std::set<std::string> exists_var_names;
  for (auto& node : graph->StmtTopologicalOrder()) {
    if (!node->IsStmt()) continue;
    auto op_info = node->AsStmt().op_info();
    auto op_type = op_info->Type();
    std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++);
    // Add the op and its input & output variables as the Dot nodes
    dot.AddNode(op_name,
                {Dot::Attr("shape", "box"),
                 Dot::Attr("style", "filled"),
                 Dot::Attr("color", "black"),
                 Dot::Attr("fillcolor", "yellow")});
    for (auto& x : node->inlinks) {
      auto var_name = x->AsArg().name;
      if (!exists_var_names.count(var_name)) {
        dot.AddNode(var_name, {});
        exists_var_names.insert(var_name);
      }
      dot.AddEdge(var_name, op_name, {});
    }
    for (auto& x : node->outlinks) {
      auto var_name = x->AsArg().name;
      if (!exists_var_names.count(var_name)) {
        dot.AddNode(var_name, {});
        exists_var_names.insert(var_name);
      }
      dot.AddEdge(op_name, var_name, {});
    }
    // Output all of the op's attributes (names and values)
    os << "* " << op_name << "\n";
    const auto& attr_names = op_info->AttrNames();
    for (auto& attr_name : attr_names) {
      os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
    }
  }
  os << dot.Build();
  return os.str();
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
    .BindTargets({TARGET(kAny)});
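The pass's output is plain Graphviz DOT text for the bipartite op/variable graph. As a standalone illustration of that output shape, a minimal sketch follows; the ToyOp struct and EmitDot() function are made-up stand-ins, not Paddle-Lite APIs:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy graph node: an op with input/output variable names.
struct ToyOp {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Emit Graphviz DOT text: ops as filled boxes, variables as plain nodes,
// edges following the dataflow - the same shape of output Visualize() builds.
std::string EmitDot(const std::vector<ToyOp>& ops) {
  std::ostringstream os;
  os << "digraph G {\n";
  for (const auto& op : ops) {
    os << "  \"" << op.name
       << "\" [shape=box, style=filled, fillcolor=yellow];\n";
    for (const auto& in : op.inputs)
      os << "  \"" << in << "\" -> \"" << op.name << "\";\n";
    for (const auto& out : op.outputs)
      os << "  \"" << op.name << "\" -> \"" << out << "\";\n";
  }
  os << "}\n";
  return os.str();
}

int main() {
  std::vector<ToyOp> ops = {{"conv2d0", {"x", "w"}, {"y"}},
                            {"relu0", {"y"}, {"z"}}};
  std::cout << EmitDot(ops);  // paste the output into any Graphviz viewer
}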
lite/core/mir/quantized_op_attributes_inference_pass.cc (+75, -0, new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/quantized_op_attributes_inference_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

void QuantizedOpAttributesInferencePass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  // Only for the fully quantized model, which is only supported by the MTK
  // and RK NPUs: fill each quantized op's 'output_scale' with the
  // 'input_scale' of the adjacent quantized ops, and set the missing
  // 'enable_int8' attribute.
  for (auto& op_node : graph->StmtTopologicalOrder()) {
    if (!op_node->IsStmt()) continue;
    auto& inst = op_node->AsStmt();
    auto op_info = inst.op_info();
    auto op_type = op_info->Type();
    if (!op_info->HasAttr("input_scale")) continue;
    bool found = false;
    float output_scale;
    for (auto out_var_node : op_node->outlinks) {
      CHECK(out_var_node->IsArg());
      for (auto out_op_node : out_var_node->outlinks) {
        CHECK(out_op_node->IsStmt());
        auto& out_inst = out_op_node->AsStmt();
        auto out_op_info = out_inst.op_info();
        if (!out_op_info->HasAttr("input_scale")) continue;
        auto input_scale = out_op_info->GetAttr<float>("input_scale");
        if (!found) {
          found = true;
          output_scale = input_scale;
        } else {
          CHECK_EQ(output_scale, input_scale);
        }
      }
    }
    if (found) {
      inst.mutable_op_info()->SetAttr("output_scale", output_scale);
    }
    if (op_info->HasAttr("output_scale")) {
      inst.mutable_op_info()->SetAttr("enable_int8", true);
    }
  }
  VLOG(5) << "\n" << Visualize(graph.get());
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

REGISTER_MIR_PASS(quantized_op_attributes_inference_pass,
                  paddle::lite::mir::QuantizedOpAttributesInferencePass)
    .BindTargets({TARGET(kNPU)});
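The rule the new pass encodes is simple: a quantized op typically carries only 'input_scale', so a producer's missing 'output_scale' can be copied from whichever quantized op consumes its output, provided all consumers agree, and any op that ends up with an 'output_scale' is flagged int8. A minimal sketch of that propagation on a toy structure (ToyOp and InferQuantAttributes below are illustrative, not the real OpInfo API):

#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for a quantized op: float attributes plus links to
// the ops that consume its output.
struct ToyOp {
  std::string type;
  std::map<std::string, float> attrs;  // e.g. "input_scale", "output_scale"
  std::vector<ToyOp*> consumers;       // ops reading this op's output
  bool enable_int8 = false;
};

// Copy output_scale from the consumers' input_scale (all consumers must
// agree), then flag the op as int8 - the same rule the pass applies.
void InferQuantAttributes(std::vector<ToyOp*>& ops) {
  for (auto* op : ops) {
    if (!op->attrs.count("input_scale")) continue;  // not a quantized op
    bool found = false;
    float output_scale = 0.f;
    for (auto* consumer : op->consumers) {
      if (!consumer->attrs.count("input_scale")) continue;
      float s = consumer->attrs.at("input_scale");
      if (!found) {
        found = true;
        output_scale = s;
      } else {
        assert(output_scale == s);  // inconsistent consumer scales
      }
    }
    if (found) op->attrs["output_scale"] = output_scale;
    if (op->attrs.count("output_scale")) op->enable_int8 = true;
  }
}

int main() {
  ToyOp conv, relu;
  conv.type = "conv2d";
  conv.attrs["input_scale"] = 0.05f;
  relu.type = "relu";
  relu.attrs["input_scale"] = 0.1f;
  conv.consumers.push_back(&relu);
  std::vector<ToyOp*> ops = {&conv, &relu};
  InferQuantAttributes(ops);
  // conv's output_scale is taken from relu's input_scale: prints "0.1 1".
  std::cout << conv.attrs["output_scale"] << " " << conv.enable_int8 << "\n";
}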
lite/core/mir/quantized_op_attributes_inference_pass.h (+36, -0, new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// (Apache License 2.0 header, as in the .cc file above.)

#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"

namespace paddle {
namespace lite {
namespace mir {

class QuantizedOpAttributesInferencePass : public mir::StmtPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
lite/core/mir/ssa_graph.cc (+12, -3)

SSAGraph::Build() now records the original output precision on fetch ops so that a later pass can restore it:

@@ -140,9 +140,18 @@ void SSAGraph::Build(const Program &program,
       arg_node->AsArg(name, node_storage_.size() - 1);
       arg_update_node_map_[name] = arg_node;
     }
-    if (var_types.count(name) && !arg_node->arg()->type) {
-      arg_node->arg()->type = LiteType::GetTensorTy(
-          TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+    if (var_types.count(name)) {
+      if (!arg_node->arg()->type) {
+        arg_node->arg()->type = LiteType::GetTensorTy(
+            TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+      }
+      // Store the original data type of the output tensors for
+      // type_precision_cast_pass, to keep the output types consistent
+      // between the original graph and the optimized graph
+      if (op->op_info()->Type() == "fetch") {
+        op->mutable_op_info()->SetAttr<int>(
+            "data_type", static_cast<int>(var_types[name]));
+      }
     }
     if (is_weights(name)) arg_node->AsArg().is_weight = true;
     CHECK(arg_node->IsRoleSet());
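The attribute machinery only stores plain ints, hence the static_cast in both directions; a tiny self-contained sketch of the store/recover round trip, with a hypothetical PrecisionType enum standing in for Paddle-Lite's real one:

#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for Paddle-Lite's PrecisionType enum.
enum class PrecisionType : int { kUnk = 0, kFloat = 1, kInt8 = 2, kInt32 = 3 };

int main() {
  std::map<std::string, int> attrs;  // int attributes on an op desc
  // Store side, as in ssa_graph.cc: SetAttr<int>("data_type", ...).
  attrs["data_type"] = static_cast<int>(PrecisionType::kInt8);
  // Recover side, as in type_precision_cast_pass.cc: GetAttr<int> + cast.
  auto data_type = static_cast<PrecisionType>(attrs.at("data_type"));
  assert(data_type == PrecisionType::kInt8);
}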
lite/core/mir/subgraph/subgraph_detector.cc (+31, -0)

When a subgraph op is fused, the scales of its boundary variables are copied onto the op as vector attributes:

@@ -372,6 +372,37 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
   subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
                                                      output_var_names);
+  // Set input/output scale values of input/output var nodes for
+  // type_precision_cast_pass.
+  std::vector<float> input_data_scales;
+  std::vector<float> output_data_scales;
+  for (auto &var_node : input_var_nodes) {
+    auto any_op_node = var_node->outlinks.front();
+    CHECK(any_op_node->IsStmt());
+    auto &any_inst = any_op_node->AsStmt();
+    if (any_inst.op_info()->HasAttr("input_scale")) {
+      input_data_scales.push_back(
+          any_inst.op_info()->GetAttr<float>("input_scale"));
+    }
+  }
+  for (auto &var_node : output_var_nodes) {
+    auto any_op_node = var_node->inlinks.front();
+    CHECK(any_op_node->IsStmt());
+    auto &any_inst = any_op_node->AsStmt();
+    if (any_inst.op_info()->HasAttr("output_scale")) {
+      output_data_scales.push_back(
+          any_inst.op_info()->GetAttr<float>("output_scale"));
+    }
+  }
+  if (input_data_scales.size() > 0) {
+    subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
+                                                 input_data_scales);
+  }
+  if (output_data_scales.size() > 0) {
+    subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
+                                                 output_data_scales);
+  }
   // Set all of the inputs and outputs to the target subgraph op
   // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram()
   for (auto &var_node : weight_var_nodes) {
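Note the invariant this establishes: input_data_scales[i] belongs to input_data_names[i], because both vectors are filled from the same node lists. The scale lookup added in type_precision_cast_pass.cc depends on exactly this positional pairing; a minimal sketch of that lookup under the same assumption (plain vectors instead of the real attribute API):

#include <cassert>
#include <string>
#include <vector>

// Positional lookup: names[i] corresponds to scales[i], the invariant
// established when both attributes are filled in the same loop.
bool LookupScale(const std::vector<std::string>& names,
                 const std::vector<float>& scales,
                 const std::string& var_name,
                 float* scale) {
  assert(names.size() == scales.size());
  for (size_t i = 0; i < names.size(); i++) {
    if (names[i] == var_name) {
      *scale = scales[i];
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<std::string> input_data_names = {"x", "y"};
  std::vector<float> input_data_scales = {0.05f, 0.1f};
  float s = 0.f;
  assert(LookupScale(input_data_names, input_data_scales, "y", &s));
  assert(s == 0.1f);
}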
lite/core/mir/type_layout_cast_pass.cc (+4, -3)

@@ -20,6 +20,8 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/type_precision_cast_pass.h"
+#include "lite/operators/subgraph_op.h"
 #include "lite/utils/string.h"

@@ -170,9 +172,8 @@ void TypeLayoutTransformPass::AddLayoutInst(
   DirectedLink(layout_output_arg, inst_node);
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                layout_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, layout_output_name);
   auto original_selected_kernel =
       std::move(inst_node->AsStmt().kernels().front());
   auto update_op_info = *inst_node->AsStmt().op_info();
lite/core/mir/type_layout_cast_pass.h (+0, -12)

The header-local UpdateInputTo() helper is removed in favor of the shared UpdateInputs() from type_precision_cast_pass.h:

@@ -24,18 +24,6 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
-
 class TypeLayoutTransformPass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
lite/core/mir/type_precision_cast_pass.cc (+118, -6)

New helpers replace the old header-local UpdateInputTo(): UpdateInputs() renames an op's input variables and, for subgraph ops, also patches the 'input_data_names' attribute and the inner block; InferScale() derives the scale for a newly inserted calib op from the neighboring ops. The include list gains "lite/operators/subgraph_op.h":

@@ -20,11 +20,115 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/operators/subgraph_op.h"

The added functions:

// For the subgraph op, we also need to update the attr 'input_data_names' and
// the input variables names of the Ops in the subblock.
void UpdateInputsForSubgraph(OpLite* op,
                             const std::string& from,
                             const std::string& to) {
  auto* op_desc = op->mutable_op_info();
  auto input_data_names =
      op_desc->GetAttr<std::vector<std::string>>("input_data_names");
  std::replace(input_data_names.begin(), input_data_names.end(), from, to);
  op_desc->SetAttr("input_data_names", input_data_names);
  auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock();
  CHECK(subblock_desc);
  for (size_t i = 0; i < subblock_desc->OpsSize(); i++) {
    auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i);
    for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) {
      for (auto& subblock_var_name : subblock_op_input.second) {
        if (subblock_var_name == from) {
          subblock_var_name = to;
        }
      }
    }
  }
}

// Update the input variable names from 'from' to 'to' for the target Op
void UpdateInputs(OpLite* op, const std::string& from, const std::string& to) {
  auto* op_desc = op->mutable_op_info();
  auto op_type = op_desc->Type();
  for (auto& op_input : *op_desc->mutable_inputs()) {
    for (auto& var_name : op_input.second) {
      if (var_name == from) {
        var_name = to;
      }
    }
  }
  if (op_type == "subgraph") {
    UpdateInputsForSubgraph(op, from, to);
  }
}

// Infer the scale value for the new calib op from the subgraph op
static bool InferScaleFromSubgraph(std::string var_name,
                                   const OpInfo* op_info,
                                   float* scale,
                                   bool reverse = false) {
  bool found = false;
  auto input_or_output_names = op_info->GetAttr<std::vector<std::string>>(
      reverse ? "output_data_names" : "input_data_names");
  auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(
      reverse ? "output_data_scales" : "input_data_scales");
  auto size = input_or_output_names.size();
  CHECK(size == input_or_output_scales.size());
  for (int i = 0; i < size; i++) {
    if (input_or_output_names[i] == var_name) {
      *scale = input_or_output_scales[i];
      found = true;
      break;
    }
  }
  return found;
}

// Infer the scale value for the new calib op from the input_scale of the
// current op and the output_scale of the previous op.
// case 1: prev_op->var_node->op_node(int8->any op, with input_scale).
// case 2: prev_op->var_node->op_node(subgraph op, int8->any, with
//         input_data_scales).
// case 3: prev_op(any->int8, with output_scale)->var_node->op_node(fp32->any,
//         without input_scale).
// case 4: prev_op(any->int8, subgraph op, with
//         output_data_scales)->var_node->op_node(fp32->any, without
//         input_scale).
static bool InferScale(Node* var_node, Node* op_node, float* scale) {
  bool found = false;
  auto& inst = op_node->AsStmt();
  auto op_info = inst.op_info();
  auto op_type = op_info->Type();
  auto var_name = var_node->AsArg().name;
  if (op_type == "subgraph") {
    found = InferScaleFromSubgraph(var_name, op_info, scale, false);
  } else {
    if (op_info->HasAttr("input_scale")) {
      *scale = op_info->GetAttr<float>("input_scale");
      found = true;
    } else {
      // Obtain the output_scale from one of its previous Ops
      auto prev_op_node = var_node->inlinks.front();
      CHECK(prev_op_node->IsStmt());
      auto& prev_inst = prev_op_node->AsStmt();
      auto prev_op_info = prev_inst.op_info();
      auto prev_op_type = prev_op_info->Type();
      if (prev_op_type == "subgraph") {
        found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
      } else {
        if (prev_op_info->HasAttr("output_scale")) {
          *scale = prev_op_info->GetAttr<float>("output_scale");
          found = true;
        }
      }
    }
  }
  return found;
}

In ComplementInputs(), a fetch op's declared input type is overridden with the original data type recorded by ssa_graph.cc:

@@ -59,6 +163,14 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
   auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
   CHECK(in->AsArg().type);
   VLOG(4) << inst.picked_kernel().name();
+  if (inst.op_info()->Type() == "fetch") {
+    if (inst.op_info()->HasAttr("data_type")) {
+      auto data_type =
+          static_cast<PrecisionType>(inst.op_info()->GetAttr<int>("data_type"));
+      decl_arg_type = LiteType::GetTensorTy(
+          decl_arg_type->target(), data_type, decl_arg_type->layout());
+    }
+  }
   // if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
   // *decl_arg_type)) {
   if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {

In AddCastInst(), the calib op's scale now comes from InferScale() instead of reading the current op's 'input_scale' attribute directly:

@@ -109,10 +221,11 @@ void PrecisionCastPass::AddCastInst(const Type& from,
   op_desc.SetType(cast_type);
   op_desc.SetInput("Input", {in->AsArg().name});
   op_desc.SetOutput("Out", {cast_op_output_name});
-  if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) {
-    op_desc.SetAttr(
-        "scale", inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
+  float scale;
+  if (InferScale(in, inst_node, &scale)) {
+    op_desc.SetAttr("scale", scale);
   }
   cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
   auto kernels = cast_op->CreateKernels(valid_places);
   std::vector<std::unique_ptr<KernelBase>> selected_kernels;

@@ -146,9 +259,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
   DirectedLink(cast_op_output_arg, inst_node);
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                cast_op_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
   // recreate the op
   auto original_selected_kernel =
       std::move(inst_node->AsStmt().kernels().front());
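One detail worth calling out: UpdateInputsForSubgraph() renames the attribute entries with a single std::replace call rather than an explicit loop. A self-contained demonstration of that call on toy data (the variable names here are made up):

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> input_data_names = {"x", "scale_out", "bias"};
  // Rename every occurrence of "scale_out" to the cast op's new output name,
  // as UpdateInputsForSubgraph() does before writing the attribute back.
  std::replace(input_data_names.begin(), input_data_names.end(),
               std::string("scale_out"), std::string("x_precision_cast"));
  assert(input_data_names[1] == "x_precision_cast");
}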
lite/core/mir/type_precision_cast_pass.h (+1, -11)

The static helper is replaced by a declaration of the shared, subgraph-aware UpdateInputs():

@@ -24,17 +24,7 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
+void UpdateInputs(OpLite* op, const std::string& from, const std::string& to);

 /*
  * The pass complement the necessary instruction to make data
lite/core/mir/type_target_cast_pass.cc (+3, -3)

@@ -21,6 +21,7 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/type_precision_cast_pass.h"
 #include "lite/utils/string.h"

@@ -240,9 +241,8 @@ void TypeTargetTransformPass::UpdateInstNode(Node* in,
                                              Node* inst_node,
                                              std::string io_copy_output_name) {
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                io_copy_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, io_copy_output_name);
   auto original_selected_kernel =
       std::move(inst_node->AsStmt().kernels().front());
   auto update_op_info = *inst_node->AsStmt().op_info();
lite/core/mir/type_target_cast_pass.h (+0, -12)

The duplicated header-local UpdateInputTo() helper is removed here as well:

@@ -25,18 +25,6 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
-
 /*
  * IoComplementPass complement the necessary instruction to make data
  * transferring or transformation between different places.
lite/core/op_registry.cc (+5, -0)

NHWC slots are registered for the ARM and NPU targets:

@@ -154,7 +154,9 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kX86, kInt64, kNCHW);
   INIT_FOR(kARM, kFloat, kNCHW);
+  INIT_FOR(kARM, kFloat, kNHWC);
   INIT_FOR(kARM, kInt8, kNCHW);
+  INIT_FOR(kARM, kInt8, kNHWC);
   INIT_FOR(kARM, kAny, kNCHW);
   INIT_FOR(kARM, kAny, kAny);
   INIT_FOR(kARM, kInt32, kNCHW);

@@ -180,8 +182,11 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kOpenCL, kAny, kImageNW);
   INIT_FOR(kNPU, kFloat, kNCHW);
+  INIT_FOR(kNPU, kFloat, kNHWC);
   INIT_FOR(kNPU, kInt8, kNCHW);
+  INIT_FOR(kNPU, kInt8, kNHWC);
   INIT_FOR(kNPU, kAny, kNCHW);
+  INIT_FOR(kNPU, kAny, kNHWC);
   INIT_FOR(kNPU, kAny, kAny);
   INIT_FOR(kXPU, kFloat, kNCHW);
lite/core/optimizer.h (+6, -0)

The new pass is inserted into the default optimization pipeline just before the NPU subgraph pass:

@@ -75,6 +75,12 @@ class Optimizer {
 (defined LITE_WITH_ARM)
          "lite_elementwise_add_activation_fuse_pass",  //
 #endif
+         "quantized_op_attributes_inference_pass",  // Only for the fully
+                                                    // quantized model: infer
+                                                    // the output scale and
+                                                    // fix the 'enable_int8'
+                                                    // attribute for all of
+                                                    // the quantized ops.
          "npu_subgraph_pass",
          "xpu_subgraph_pass",
          "bm_subgraph_pass",
lite/core/tensor.cc (+2, -1)

ShareDataWith() now propagates the precision as well, and CopyDataFrom() reads the member directly instead of going through the accessor:

@@ -75,6 +75,7 @@ void TensorLite::ShareDataWith(const TensorLite &other) {
   target_ = other.target_;
   lod_ = other.lod_;
   memory_size_ = other.memory_size_;
+  precision_ = other.precision_;
 }

 void TensorLite::CopyDataFrom(const TensorLite &other) {

@@ -82,7 +83,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
   target_ = other.target_;
   lod_ = other.lod_;
   memory_size_ = other.memory_size_;
-  precision_ = other.precision();
+  precision_ = other.precision_;
   buffer_->CopyDataFrom(*other.buffer_, memory_size_);
 }
lite/kernels/arm/calib_compute.cc (+107, -34)

Both calib kernels become templates over the data layout, so the same implementation can be registered for NCHW and NHWC:

template <DataLayoutType DLType>
void CalibComputeFp32ToInt8<DLType>::Run() {
  auto& param = this->template Param<operators::CalibParam>();
  std::vector<float> scale = {param.scale};
  const auto* din = param.input->template data<float>();
  auto* dout = param.output->template mutable_data<signed char>();
  lite::arm::math::fp32_to_int8(
      din, dout, scale.data(), 1, 1, param.input->numel());
  return;
}

template <DataLayoutType DLType>
void CalibComputeInt8ToFp32<DLType>::Run() {
  auto& param = this->template Param<operators::CalibParam>();
  const auto* din = param.input->template data<signed char>();
  std::vector<float> scale = {param.scale};
  auto* dout = param.output->template mutable_data<float>();
  lite::arm::math::int8_to_fp32(
      din, dout, scale.data(), 1, 1, param.input->numel());
  return;
}

Each of calib and calib_once is now registered four times: {fp32_to_int8, int8_to_fp32} x {kNCHW, kNHWC}, with the bound tensor types carrying DATALAYOUT(kNHWC) in the NHWC variants. Two representative registrations:

REGISTER_LITE_KERNEL(
    calib,
    kARM,
    kInt8,
    kNCHW,
    paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNCHW)>,
    fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .Finalize();

REGISTER_LITE_KERNEL(
    calib,
    kARM,
    kInt8,
    kNHWC,
    paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNHWC)>,
    fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(
                   TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNHWC))})
    .BindOutput("Out",
                {LiteType::GetTensorTy(
                    TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNHWC))})
    .Finalize();

The int8_to_fp32 registrations mirror these with the Input/Out precisions swapped, and the four calib_once registrations repeat the same pattern under the calib_once name.
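The conversion itself lives in lite::arm::math::fp32_to_int8 / int8_to_fp32, which this diff does not show. As a reference for what a symmetric per-tensor calib conversion computes, a minimal scalar sketch, assuming saturating round-to-nearest into [-127, 127]; the real ARM routines are vectorized and their rounding details may differ:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference: q = clamp(round(x / scale), -127, 127), x' = q * scale.
void Fp32ToInt8(const float* din, int8_t* dout, float scale, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
    float v = std::round(din[i] / scale);
    dout[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, v)));
  }
}

void Int8ToFp32(const int8_t* din, float* dout, float scale, int64_t n) {
  for (int64_t i = 0; i < n; i++) dout[i] = din[i] * scale;
}

int main() {
  std::vector<float> x = {-1.0f, 0.049f, 0.5f, 10.0f};
  std::vector<int8_t> q(x.size());
  std::vector<float> y(x.size());
  Fp32ToInt8(x.data(), q.data(), /*scale=*/0.05f, x.size());
  Int8ToFp32(q.data(), y.data(), 0.05f, q.size());
  for (size_t i = 0; i < x.size(); i++)
    std::printf("%g -> %d -> %g\n", x[i], q[i], y[i]);  // 10.0 saturates: 127
}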
lite/kernels/arm/calib_compute.h (+4, -2)

@@ -21,8 +21,9 @@ namespace lite {
 namespace kernels {
 namespace arm {

+template <DataLayoutType DLType>
 class CalibComputeFp32ToInt8
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
  public:
   using param_t = operators::CalibParam;

@@ -33,8 +34,9 @@ class CalibComputeFp32ToInt8
  private:
 };

+template <DataLayoutType DLType>
 class CalibComputeInt8ToFp32
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
  public:
   using param_t = operators::CalibParam;
lite/kernels/arm/layout_compute.cc (+42, -32)

The layout macros no longer hard-fail on non-4D inputs: the CHECK is replaced with a warning, and the output simply shares the input's data when the rank is not 4. The rewritten macros:

#define NCHWTONHWC(type)                                                  \
  auto& param = this->template Param<param_t>();                          \
  auto input = param.x->template data<type>();                           \
  auto input_dim = param.x->dims();                                      \
  if (input_dim.size() != 4) {                                           \
    LOG(WARNING) << "NCHW to NHWC should guarantee that the input dims " \
                    "should be 4, but received "                         \
                 << input_dim.size();                                    \
    param.y->ShareDataWith(*param.x);                                    \
    return;                                                              \
  }                                                                      \
  int n = input_dim[0];                                                  \
  int c = input_dim[1];                                                  \
  int h = input_dim[2];                                                  \
  int w = input_dim[3];                                                  \
  param.y->Resize({n, h, w, c});                                         \
  auto output = param.y->template mutable_data<type>(TARGET(kARM));      \
  if (c == 1) {                                                          \
    memcpy(output, input, sizeof(type) * n * h * w);                     \
    return;                                                              \
  }                                                                      \
  lite::arm::math::NCHW2NHWC<type>(n, c, h * w, input, output);

#define NHWCTONCHW(type)                                                  \
  auto& param = this->template Param<param_t>();                          \
  auto input = param.x->template data<type>();                           \
  auto input_dim = param.x->dims();                                      \
  if (input_dim.size() != 4) {                                           \
    LOG(WARNING) << "NHWC to NCHW should guarantee that the input dims " \
                    "should be 4, but received "                         \
                 << input_dim.size();                                    \
    param.y->ShareDataWith(*param.x);                                    \
    return;                                                              \
  }                                                                      \
  int n = input_dim[0];                                                  \
  int h = input_dim[1];                                                  \
  int w = input_dim[2];                                                  \
  int c = input_dim[3];                                                  \
  param.y->Resize({n, c, h, w});                                         \
  auto output = param.y->template mutable_data<type>(TARGET(kARM));      \
  if (c == 1) {                                                          \
    memcpy(output, input, sizeof(type) * n * h * w);                     \
    return;                                                              \
  }                                                                      \
  lite::arm::math::NHWC2NCHW<type>(n, c, h * w, input, output);
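Likewise, the actual transpose is delegated to lite::arm::math::NCHW2NHWC / NHWC2NCHW, not shown in this diff. A scalar reference for the NCHW-to-NHWC case, matching the (n, c, h * w) call shape used by the macros; the real routines are NEON-optimized:

#include <cassert>
#include <vector>

// Scalar reference for NCHW -> NHWC: out[n][hw][c] = in[n][c][hw], where
// 'hw' is the flattened spatial size h * w.
template <typename T>
void Nchw2Nhwc(int n, int c, int hw, const T* in, T* out) {
  for (int i = 0; i < n; i++)
    for (int j = 0; j < c; j++)
      for (int k = 0; k < hw; k++)
        out[(i * hw + k) * c + j] = in[(i * c + j) * hw + k];
}

int main() {
  // One sample, 2 channels, 2x1 spatial: NCHW = [c0p0, c0p1, c1p0, c1p1].
  std::vector<int> in = {1, 2, 3, 4}, out(4);
  Nchw2Nhwc(1, 2, 2, in.data(), out.data());
  assert((out == std::vector<int>{1, 3, 2, 4}));  // NHWC interleaves channels
}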
lite/kernels/host/feed_compute.cc (+4, -5)

FeedCompute declares an explicit kAny layout, and its registration pins the kernel slot to kNCHW with kAny-precision tensor bindings:

@@ -20,8 +20,7 @@ namespace lite {
 namespace kernels {
 namespace host {

-class FeedCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
+class FeedCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
   using param_t = operators::FeedParam;

@@ -40,7 +39,7 @@ class FeedCompute
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    feed, kHost, kAny, kNCHW, paddle::lite::kernels::host::FeedCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
     .Finalize();
lite/kernels/host/fetch_compute.cc (+4, -9)

FetchCompute gets the same treatment, with fully unconstrained (kAny layout, device -1) tensor bindings:

@@ -20,8 +20,7 @@ namespace lite {
 namespace kernels {
 namespace host {

-class FetchCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
+class FetchCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
  public:
   using param_t = operators::FeedParam;

@@ -43,11 +42,7 @@ class FetchCompute
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
+    fetch, kHost, kAny, kNCHW, paddle::lite::kernels::host::FetchCompute, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(
+                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
     .Finalize();