PaddlePaddle / Paddle-Lite, commit 08a3ed12 (unverified)

Authored Mar 10, 2020 by hong19860320; committed via GitHub on Mar 10, 2020

[CORE] Support the fully quantized model for MTK and RK NPU (#3096)

Parent: 89da9953
Showing 21 changed files with 537 additions and 171 deletions (+537 -171)
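For context on the scale attributes this commit threads through the optimizer passes: in a fully quantized model every tensor carries an affine scale, so an int8 value q stands for the real value q * scale, and each quantized op records an input_scale and (after this commit) an output_scale. A minimal sketch of that mapping, purely illustrative and not a Paddle-Lite API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Symmetric int8 quantization: real is approximately q * scale, with q in
// [-127, 127]. This is the contract behind the 'input_scale'/'output_scale'
// attributes the passes below propagate (sketch only, not Paddle-Lite code).
int8_t Quantize(float real, float scale) {
  int q = static_cast<int>(std::round(real / scale));
  return static_cast<int8_t>(std::min(127, std::max(-127, q)));
}

float Dequantize(int8_t q, float scale) { return q * scale; }
```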
Changed files:

  lite/api/paddle_use_passes.h                              +2    -1
  lite/core/mir/CMakeLists.txt                              +1    -0
  lite/core/mir/graph_visualize_pass.cc                     +80   -37
  lite/core/mir/quantized_op_attributes_inference_pass.cc   +75   -0
  lite/core/mir/quantized_op_attributes_inference_pass.h    +36   -0
  lite/core/mir/ssa_graph.cc                                +12   -3
  lite/core/mir/subgraph/subgraph_detector.cc               +31   -0
  lite/core/mir/type_layout_cast_pass.cc                    +4    -3
  lite/core/mir/type_layout_cast_pass.h                     +0    -12
  lite/core/mir/type_precision_cast_pass.cc                 +118  -6
  lite/core/mir/type_precision_cast_pass.h                  +1    -11
  lite/core/mir/type_target_cast_pass.cc                    +3    -3
  lite/core/mir/type_target_cast_pass.h                     +0    -12
  lite/core/op_registry.cc                                  +5    -0
  lite/core/optimizer.h                                     +6    -0
  lite/core/tensor.cc                                       +2    -1
  lite/kernels/arm/calib_compute.cc                         +107  -34
  lite/kernels/arm/calib_compute.h                          +4    -2
  lite/kernels/arm/layout_compute.cc                        +42   -32
  lite/kernels/host/feed_compute.cc                         +4    -5
  lite/kernels/host/fetch_compute.cc                        +4    -9
lite/api/paddle_use_passes.h  (view file @ 08a3ed12)

@@ -24,7 +24,7 @@ USE_MIR_PASS(generate_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
 USE_MIR_PASS(runtime_context_assign_pass);
-USE_MIR_PASS(graph_visualze);
+USE_MIR_PASS(graph_visualize_pass);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(lite_fc_fuse_pass);

@@ -46,3 +46,4 @@ USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
 USE_MIR_PASS(weight_quantization_preprocess_pass);
+USE_MIR_PASS(quantized_op_attributes_inference_pass);
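A note on these macros: USE_MIR_PASS at a use site pairs with REGISTER_MIR_PASS at the pass's definition site, and registration-macro pairs of this kind commonly work by referencing a symbol exported by the registration so the linker cannot dead-strip the object file that registers the pass. A hedged sketch of that common idiom (the real Paddle-Lite macro definitions differ in detail):

```cpp
#include <cstdio>

// Sketch of the "touch a symbol" idiom often behind
// REGISTER_MIR_PASS / USE_MIR_PASS pairs; illustrative only.
#define REGISTER_PASS_SKETCH(name)               \
  int TouchPass_##name() {                       \
    std::printf("registered pass: " #name "\n"); \
    return 0;                                    \
  }

#define USE_PASS_SKETCH(name)      \
  extern int TouchPass_##name();   \
  static int use_pass_##name = TouchPass_##name();

REGISTER_PASS_SKETCH(quantized_op_attributes_inference_pass)  // pass .cc side
USE_PASS_SKETCH(quantized_op_attributes_inference_pass)       // user .cc side

int main() { return 0; }
```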
lite/core/mir/CMakeLists.txt  (view file @ 08a3ed12)

@@ -36,6 +36,7 @@ lite_cc_library(mir_passes
     runtime_context_assign_pass.cc
     memory_optimize_pass.cc
     weight_quantization_preprocess_pass.cc
+    quantized_op_attributes_inference_pass.cc
   DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})

 # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
lite/core/mir/graph_visualize_pass.cc  (view file @ 08a3ed12)

@@ -18,6 +18,7 @@
 #include <set>
 #include <string>
 #include <utility>
+#include <vector>
 #include "lite/core/mir/pass_registry.h"
 #include "lite/utils/string.h"

@@ -28,56 +29,98 @@ namespace mir {
 using inference::analysis::Dot;

 void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  Visualize(graph.get());
+  VLOG(5) << "\n" << Visualize(graph.get());
 }

 std::string Visualize(mir::SSAGraph* graph) {
+  std::ostringstream os;
   inference::analysis::Dot dot;
-  int id = 0;
-  std::set<std::string> exists_args;
-  for (auto& node : graph->mutable_nodes()) {
-    std::string key;
-    if (node.IsArg()) {
-      key = node.AsArg().name;
-    } else {
-      key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++);
-    }
-    if (node.IsStmt()) {
-      dot.AddNode(key,
-                  {Dot::Attr("shape", "box"),
-                   Dot::Attr("style", "filled"),
-                   Dot::Attr("color", "black"),
-                   Dot::Attr("fillcolor", "yellow")});
-      for (auto& x : node.inlinks) {
-        auto name = x->AsArg().name;
-        if (!exists_args.count(name)) {
-          dot.AddNode(name, {});
-        }
-        dot.AddEdge(name, key, {});
-        exists_args.insert(name);
-      }
-      for (auto& x : node.outlinks) {
-        auto name = x->AsArg().name;
-        if (!exists_args.count(name)) {
-          dot.AddNode(name, {});
-        }
-        dot.AddEdge(key, name, {});
-        exists_args.insert(name);
-      }
-    }
-  }
-  auto res = dot.Build();
-  // If we use VLOG here, we can not type all graph out.
-  // So we change VLOG to std::cout.
-  std::cout << "dot:\n" << res << std::endl;
-  return res;
+  auto string_trunc = [](const std::string& str) -> std::string {
+    const int max_disp_size = 100;
+    if (str.length() > max_disp_size)
+      return str.substr(0, max_disp_size) + "...";
+    return str;
+  };
+  auto attr_repr = [&](const OpInfo* op_info,
+                       const std::string& attr_name) -> std::string {
+    std::ostringstream os;
+    using AttrType = cpp::OpDesc::AttrType;
+    auto attr_type = op_info->GetAttrType(attr_name);
+    switch (attr_type) {
+      case AttrType::INT:
+        os << ":int:" << std::to_string(op_info->GetAttr<int>(attr_name));
+        break;
+      case AttrType::FLOAT:
+        os << ":float:" << std::to_string(op_info->GetAttr<float>(attr_name));
+        break;
+      case AttrType::BOOLEAN:
+        os << ":int:" << std::to_string(op_info->GetAttr<bool>(attr_name));
+        break;
+      case AttrType::STRING:
+        os << ":string:\""
+           << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
+        break;
+      case AttrType::FLOATS: {
+        auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
+        os << ":floats: {" + Join(vals, ",") << "}";
+      } break;
+      case AttrType::INTS: {
+        auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
+        os << ":ints: {" + Join(vals, ",") + "}";
+      } break;
+      case AttrType::STRINGS: {
+        auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
+        os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
+      } break;
+      default:
+        os << ":Unknow type(" << static_cast<int>(attr_type) << ")";
+        break;
+    }
+    return os.str();
+  };
+  int op_idx = 0;
+  std::set<std::string> exists_var_names;
+  for (auto& node : graph->StmtTopologicalOrder()) {
+    if (!node->IsStmt()) continue;
+    auto op_info = node->AsStmt().op_info();
+    auto op_type = op_info->Type();
+    std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++);
+    // Add its input&output variables as the Dot nodes
+    dot.AddNode(op_name,
+                {Dot::Attr("shape", "box"),
+                 Dot::Attr("style", "filled"),
+                 Dot::Attr("color", "black"),
+                 Dot::Attr("fillcolor", "yellow")});
+    for (auto& x : node->inlinks) {
+      auto var_name = x->AsArg().name;
+      if (!exists_var_names.count(var_name)) {
+        dot.AddNode(var_name, {});
+        exists_var_names.insert(var_name);
+      }
+      dot.AddEdge(var_name, op_name, {});
+    }
+    for (auto& x : node->outlinks) {
+      auto var_name = x->AsArg().name;
+      if (!exists_var_names.count(var_name)) {
+        dot.AddNode(var_name, {});
+        exists_var_names.insert(var_name);
+      }
+      dot.AddEdge(op_name, var_name, {});
+    }
+    // Output its all of attributes(name and values)
+    os << "* " << op_name << "\n";
+    const auto& attr_names = op_info->AttrNames();
+    for (auto& attr_name : attr_names) {
+      os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
+    }
+  }
+  os << dot.Build();
+  return os.str();
 }

 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle

-REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
+REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
     .BindTargets({TARGET(kAny)});
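With this change, Visualize() returns the DOT text (plus a per-op attribute listing) instead of dumping it to std::cout, and the pass logs it at VLOG(5). To inspect a graph, the returned string can be written to a file and rendered with Graphviz; a small sketch (WriteDot is an illustrative helper, not part of Paddle-Lite):

```cpp
#include <fstream>
#include <string>

// Persist the DOT source returned by Visualize(graph.get()) so it can be
// rendered offline, e.g. with `dot -Tpng graph.dot -o graph.png` (Graphviz).
void WriteDot(const std::string& dot_source, const std::string& path) {
  std::ofstream out(path);
  out << dot_source;  // the string already contains the full digraph
}
```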
lite/core/mir/quantized_op_attributes_inference_pass.cc  (new file, mode 100644; view file @ 08a3ed12)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/mir/quantized_op_attributes_inference_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"

namespace paddle {
namespace lite {
namespace mir {

void QuantizedOpAttributesInferencePass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  // Only for fully quantized model which is only supported by MTK and RK NPU.
  // Replace the output_scale with the input_scale of the adjacent quantized
  // ops, and fix the missing of the attribute 'enable_int8'.
  for (auto& op_node : graph->StmtTopologicalOrder()) {
    if (!op_node->IsStmt()) continue;
    auto& inst = op_node->AsStmt();
    auto op_info = inst.op_info();
    auto op_type = op_info->Type();
    if (!op_info->HasAttr("input_scale")) continue;
    bool found = false;
    float output_scale;
    for (auto out_var_node : op_node->outlinks) {
      CHECK(out_var_node->IsArg());
      for (auto out_op_node : out_var_node->outlinks) {
        CHECK(out_op_node->IsStmt());
        auto& out_inst = out_op_node->AsStmt();
        auto out_op_info = out_inst.op_info();
        if (!out_op_info->HasAttr("input_scale")) continue;
        auto input_scale = out_op_info->GetAttr<float>("input_scale");
        if (!found) {
          found = true;
          output_scale = input_scale;
        } else {
          CHECK_EQ(output_scale, input_scale);
        }
      }
    }
    if (found) {
      inst.mutable_op_info()->SetAttr("output_scale", output_scale);
    }
    if (op_info->HasAttr("output_scale")) {
      inst.mutable_op_info()->SetAttr("enable_int8", true);
    }
  }
  VLOG(5) << "\n" << Visualize(graph.get());
}

}  // namespace mir
}  // namespace lite
}  // namespace paddle

REGISTER_MIR_PASS(quantized_op_attributes_inference_pass,
                  paddle::lite::mir::QuantizedOpAttributesInferencePass)
    .BindTargets({TARGET(kNPU)});
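In short, the pass back-fills each quantized op's output_scale from the input_scale its consumers declare, then marks any op that ends up with an output_scale as enable_int8. A made-up walk-through: if conv2d0 (input_scale = 0.02) feeds relu0 (input_scale = 0.05), conv2d0 gets output_scale = 0.05 and enable_int8 = true. The standalone sketch below mirrors that propagation on a toy two-op graph (illustrative only; ToyOp is not a Paddle-Lite type):

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Toy model of the propagation done by
// quantized_op_attributes_inference_pass (sketch, not repo code).
struct ToyOp {
  std::string type;
  bool has_input_scale;
  float input_scale;
  bool has_output_scale;
  float output_scale;
  bool enable_int8;
  std::vector<int> consumers;  // indices of downstream ops
};

void InferQuantAttributes(std::vector<ToyOp>* ops) {
  for (auto& op : *ops) {
    if (!op.has_input_scale) continue;  // not a quantized op
    bool found = false;
    float output_scale = 0.f;
    for (int c : op.consumers) {
      const ToyOp& next = (*ops)[c];
      if (!next.has_input_scale) continue;
      if (!found) {
        found = true;
        output_scale = next.input_scale;
      } else {
        assert(output_scale == next.input_scale);  // consumers must agree
      }
    }
    if (found) {
      op.has_output_scale = true;
      op.output_scale = output_scale;
    }
    if (op.has_output_scale) op.enable_int8 = true;
  }
}

int main() {
  std::vector<ToyOp> ops = {
      {"conv2d", true, 0.02f, false, 0.f, false, {1}},
      {"relu", true, 0.05f, false, 0.f, false, {}}};
  InferQuantAttributes(&ops);
  std::printf("%s: output_scale=%g enable_int8=%d\n",
              ops[0].type.c_str(), ops[0].output_scale, ops[0].enable_int8);
  // prints: conv2d: output_scale=0.05 enable_int8=1
  return 0;
}
```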
lite/core/mir/quantized_op_attributes_inference_pass.h  (new file, mode 100644; view file @ 08a3ed12)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"

namespace paddle {
namespace lite {
namespace mir {

class QuantizedOpAttributesInferencePass : public mir::StmtPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle
lite/core/mir/ssa_graph.cc  (view file @ 08a3ed12)

@@ -140,9 +140,18 @@ void SSAGraph::Build(const Program &program,
       arg_node->AsArg(name, node_storage_.size() - 1);
       arg_update_node_map_[name] = arg_node;
     }
-    if (var_types.count(name) && !arg_node->arg()->type) {
-      arg_node->arg()->type = LiteType::GetTensorTy(
-          TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+    if (var_types.count(name)) {
+      if (!arg_node->arg()->type) {
+        arg_node->arg()->type = LiteType::GetTensorTy(
+            TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+      }
+      // Store the original data type of the output tensors for
+      // type_precision_cast_pass, to keep the consistency between the
+      // output types of original graph and optimized graph's
+      if (op->op_info()->Type() == "fetch") {
+        op->mutable_op_info()->SetAttr<int>(
+            "data_type", static_cast<int>(var_types[name]));
+      }
     }
     if (is_weights(name)) arg_node->AsArg().is_weight = true;
     CHECK(arg_node->IsRoleSet());
lite/core/mir/subgraph/subgraph_detector.cc  (view file @ 08a3ed12)

@@ -372,6 +372,37 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
   subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
                                                      output_var_names);
+  // Set input/output scale values of input/output var nodes for
+  // type_precision_cast_pass.
+  std::vector<float> input_data_scales;
+  std::vector<float> output_data_scales;
+  for (auto &var_node : input_var_nodes) {
+    auto any_op_node = var_node->outlinks.front();
+    CHECK(any_op_node->IsStmt());
+    auto &any_inst = any_op_node->AsStmt();
+    if (any_inst.op_info()->HasAttr("input_scale")) {
+      input_data_scales.push_back(
+          any_inst.op_info()->GetAttr<float>("input_scale"));
+    }
+  }
+  for (auto &var_node : output_var_nodes) {
+    auto any_op_node = var_node->inlinks.front();
+    CHECK(any_op_node->IsStmt());
+    auto &any_inst = any_op_node->AsStmt();
+    if (any_inst.op_info()->HasAttr("output_scale")) {
+      output_data_scales.push_back(
+          any_inst.op_info()->GetAttr<float>("output_scale"));
+    }
+  }
+  if (input_data_scales.size() > 0) {
+    subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
+                                                 input_data_scales);
+  }
+  if (output_data_scales.size() > 0) {
+    subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
+                                                 output_data_scales);
+  }
   // Set all of the inputs and outputs to the target subgraph op
   // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram()
   for (auto &var_node : weight_var_nodes) {
lite/core/mir/type_layout_cast_pass.cc  (view file @ 08a3ed12)

@@ -20,6 +20,8 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/type_precision_cast_pass.h"
+#include "lite/operators/subgraph_op.h"
 #include "lite/utils/string.h"

 namespace paddle {

@@ -170,9 +172,8 @@ void TypeLayoutTransformPass::AddLayoutInst(
   DirectedLink(layout_output_arg, inst_node);

   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                layout_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, layout_output_name);
   auto original_selected_kernel =
       std::move(inst_node->AsStmt().kernels().front());
   auto update_op_info = *inst_node->AsStmt().op_info();
lite/core/mir/type_layout_cast_pass.h  (view file @ 08a3ed12)

@@ -24,18 +24,6 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
-
 class TypeLayoutTransformPass : public ProgramPass {
  public:
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
lite/core/mir/type_precision_cast_pass.cc  (view file @ 08a3ed12)

@@ -20,11 +20,115 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/operators/subgraph_op.h"

 namespace paddle {
 namespace lite {
 namespace mir {

+// For the subgraph op, we also need to update the attr 'input_data_names' and
+// the input variables names of the Ops in the subblock.
+void UpdateInputsForSubgraph(OpLite* op,
+                             const std::string& from,
+                             const std::string& to) {
+  auto* op_desc = op->mutable_op_info();
+  auto input_data_names =
+      op_desc->GetAttr<std::vector<std::string>>("input_data_names");
+  std::replace(input_data_names.begin(), input_data_names.end(), from, to);
+  op_desc->SetAttr("input_data_names", input_data_names);
+  auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock();
+  CHECK(subblock_desc);
+  for (size_t i = 0; i < subblock_desc->OpsSize(); i++) {
+    auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i);
+    for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) {
+      for (auto& subblock_var_name : subblock_op_input.second) {
+        if (subblock_var_name == from) {
+          subblock_var_name = to;
+        }
+      }
+    }
+  }
+}
+
+// Update the input variable names from 'from' to 'to' for the target Op
+void UpdateInputs(OpLite* op, const std::string& from, const std::string& to) {
+  auto* op_desc = op->mutable_op_info();
+  auto op_type = op_desc->Type();
+  for (auto& op_input : *op_desc->mutable_inputs()) {
+    for (auto& var_name : op_input.second) {
+      if (var_name == from) {
+        var_name = to;
+      }
+    }
+  }
+  if (op_type == "subgraph") {
+    UpdateInputsForSubgraph(op, from, to);
+  }
+}
+
+// Infer the scale value for the new calib op from the subgraph op
+static bool InferScaleFromSubgraph(std::string var_name,
+                                   const OpInfo* op_info,
+                                   float* scale,
+                                   bool reverse = false) {
+  bool found = false;
+  auto input_or_output_names = op_info->GetAttr<std::vector<std::string>>(
+      reverse ? "output_data_names" : "input_data_names");
+  auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(
+      reverse ? "output_data_scales" : "input_data_scales");
+  auto size = input_or_output_names.size();
+  CHECK(size == input_or_output_scales.size());
+  for (int i = 0; i < size; i++) {
+    if (input_or_output_names[i] == var_name) {
+      *scale = input_or_output_scales[i];
+      found = true;
+      break;
+    }
+  }
+  return found;
+}
+
+// Infer the scale value for the new calib op from the input_scale of the
+// current op and output_scale of the previous op.
+// case 1: prev_op->var_node->op_node(int8->any op, with input_scale).
+// case 2: prev_op->var_node->op_node(subgraph op, int8->any, with
+// input_data_scales).
+// case 3: prev_op(any->int8, with output_scale)->var_node->op_node(fp32->any,
+// without input_scale).
+// case 4: prev_op(any->int8, subgraph_op, with
+// output_data_scales)->var_node->op_node(fp32->any, without input_scale).
+static bool InferScale(Node* var_node, Node* op_node, float* scale) {
+  bool found = false;
+  auto& inst = op_node->AsStmt();
+  auto op_info = inst.op_info();
+  auto op_type = op_info->Type();
+  auto var_name = var_node->AsArg().name;
+  if (op_type == "subgraph") {
+    found = InferScaleFromSubgraph(var_name, op_info, scale, false);
+  } else {
+    if (op_info->HasAttr("input_scale")) {
+      *scale = op_info->GetAttr<float>("input_scale");
+      found = true;
+    } else {
+      // Obtain the output_scale from one of its previous Ops
+      auto prev_op_node = var_node->inlinks.front();
+      CHECK(prev_op_node->IsStmt());
+      auto& prev_inst = prev_op_node->AsStmt();
+      auto prev_op_info = prev_inst.op_info();
+      auto prev_op_type = prev_op_info->Type();
+      if (prev_op_type == "subgraph") {
+        found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
+      } else {
+        if (prev_op_info->HasAttr("output_scale")) {
+          *scale = prev_op_info->GetAttr<float>("output_scale");
+          found = true;
+        }
+      }
+    }
+  }
+  return found;
+}
+
 void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Start from inputs of the graph, those should have place set.
   std::list<Node*> nodes;

@@ -59,6 +163,14 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
   auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
   CHECK(in->AsArg().type);
   VLOG(4) << inst.picked_kernel().name();
+  if (inst.op_info()->Type() == "fetch") {
+    if (inst.op_info()->HasAttr("data_type")) {
+      auto data_type = static_cast<PrecisionType>(
+          inst.op_info()->GetAttr<int>("data_type"));
+      decl_arg_type = LiteType::GetTensorTy(
+          decl_arg_type->target(), data_type, decl_arg_type->layout());
+    }
+  }
   // if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
   // *decl_arg_type)) {
   if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {

@@ -109,10 +221,11 @@ void PrecisionCastPass::AddCastInst(const Type& from,
   op_desc.SetType(cast_type);
   op_desc.SetInput("Input", {in->AsArg().name});
   op_desc.SetOutput("Out", {cast_op_output_name});
-  if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) {
-    op_desc.SetAttr(
-        "scale", inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
+  float scale;
+  if (InferScale(in, inst_node, &scale)) {
+    op_desc.SetAttr("scale", scale);
   }
   cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
   auto kernels = cast_op->CreateKernels(valid_places);
   std::vector<std::unique_ptr<KernelBase>> selected_kernels;

@@ -146,9 +259,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
   DirectedLink(cast_op_output_arg, inst_node);

   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                cast_op_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
   // recreate the op
   auto original_selected_kernel =
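The upshot of ComplementInputs/AddCastInst: whenever a variable's actual precision differs from what the consuming kernel declares, the pass inserts a calib op and stamps it with the scale found by InferScale. Which calib kernel gets picked follows from the two precisions; a toy sketch of just that decision (illustrative naming, not the pass's real kernel-selection code):

```cpp
#include <cstdio>

// Toy illustration of the edge-level decision behind ComplementInputs:
// a precision mismatch between producer and consumer yields a calib op.
enum class Precision { kFloat, kInt8 };

const char* PickCalibKernel(Precision actual, Precision declared) {
  if (actual == Precision::kFloat && declared == Precision::kInt8)
    return "calib (fp32_to_int8)";
  if (actual == Precision::kInt8 && declared == Precision::kFloat)
    return "calib (int8_to_fp32)";
  return nullptr;  // precisions already compatible, no cast inserted
}

int main() {
  std::printf("%s\n", PickCalibKernel(Precision::kFloat, Precision::kInt8));
  return 0;
}
```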
lite/core/mir/type_precision_cast_pass.h  (view file @ 08a3ed12)

@@ -24,17 +24,7 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
+void UpdateInputs(OpLite* op, const std::string& from, const std::string& to);

 /*
  * The pass complement the necessary instruction to make data
lite/core/mir/type_target_cast_pass.cc  (view file @ 08a3ed12)

@@ -21,6 +21,7 @@
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
 #include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/type_precision_cast_pass.h"
 #include "lite/utils/string.h"

 namespace paddle {

@@ -240,9 +241,8 @@ void TypeTargetTransformPass::UpdateInstNode(Node* in,
                                              Node* inst_node,
                                              std::string io_copy_output_name) {
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                io_copy_output_name);
+  UpdateInputs(
+      inst_node->AsStmt().op().get(), in->AsArg().name, io_copy_output_name);
   auto original_selected_kernel =
       std::move(inst_node->AsStmt().kernels().front());
   auto update_op_info = *inst_node->AsStmt().op_info();
lite/core/mir/type_target_cast_pass.h  (view file @ 08a3ed12)

@@ -25,18 +25,6 @@ namespace paddle {
 namespace lite {
 namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
-
 /*
  * IoComplementPass complement the necessary instruction to make data
  * transferring or transformation between different places.
lite/core/op_registry.cc  (view file @ 08a3ed12)

@@ -154,7 +154,9 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kX86, kInt64, kNCHW);

   INIT_FOR(kARM, kFloat, kNCHW);
+  INIT_FOR(kARM, kFloat, kNHWC);
   INIT_FOR(kARM, kInt8, kNCHW);
+  INIT_FOR(kARM, kInt8, kNHWC);
   INIT_FOR(kARM, kAny, kNCHW);
   INIT_FOR(kARM, kAny, kAny);
   INIT_FOR(kARM, kInt32, kNCHW);

@@ -180,8 +182,11 @@ KernelRegistry::KernelRegistry()
   INIT_FOR(kOpenCL, kAny, kImageNW);

   INIT_FOR(kNPU, kFloat, kNCHW);
+  INIT_FOR(kNPU, kFloat, kNHWC);
+  INIT_FOR(kNPU, kInt8, kNCHW);
+  INIT_FOR(kNPU, kInt8, kNHWC);
   INIT_FOR(kNPU, kAny, kNCHW);
   INIT_FOR(kNPU, kAny, kNHWC);
   INIT_FOR(kNPU, kAny, kAny);

   INIT_FOR(kXPU, kFloat, kNCHW);
lite/core/optimizer.h  (view file @ 08a3ed12)

@@ -75,6 +75,12 @@ class Optimizer {
                                       (defined LITE_WITH_ARM)
           "lite_elementwise_add_activation_fuse_pass",  //
 #endif
+          "quantized_op_attributes_inference_pass",  // Only for fully
+                                                     // quantized model, infer
+                                                     // the output scale and
+                                                     // fix the attribute
+                                                     // 'enable_int8' for all
+                                                     // of the quantized ops.
           "npu_subgraph_pass",
           "xpu_subgraph_pass",
           "bm_subgraph_pass",
lite/core/tensor.cc  (view file @ 08a3ed12)

@@ -75,6 +75,7 @@ void TensorLite::ShareDataWith(const TensorLite &other) {
   target_ = other.target_;
   lod_ = other.lod_;
   memory_size_ = other.memory_size_;
+  precision_ = other.precision_;
 }

 void TensorLite::CopyDataFrom(const TensorLite &other) {

@@ -82,7 +83,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
   target_ = other.target_;
   lod_ = other.lod_;
   memory_size_ = other.memory_size_;
-  precision_ = other.precision();
+  precision_ = other.precision_;
   buffer_->CopyDataFrom(*other.buffer_, memory_size_);
 }
lite/kernels/arm/calib_compute.cc  (view file @ 08a3ed12)

@@ -23,24 +23,24 @@ namespace lite {
 namespace kernels {
 namespace arm {

-void CalibComputeFp32ToInt8::Run() {
-  auto& param = this->Param<operators::CalibParam>();
+template <DataLayoutType DLType>
+void CalibComputeFp32ToInt8<DLType>::Run() {
+  auto& param = this->template Param<operators::CalibParam>();
   std::vector<float> scale = {param.scale};
-  const auto* din = param.input->data<float>();
-  auto* dout = param.output->mutable_data<signed char>();
+  const auto* din = param.input->template data<float>();
+  auto* dout = param.output->template mutable_data<signed char>();
   lite::arm::math::fp32_to_int8(
       din, dout, scale.data(), 1, 1, param.input->numel());
   return;
 }

-void CalibComputeInt8ToFp32::Run() {
-  auto& param = this->Param<operators::CalibParam>();
-  const auto* din = param.input->data<signed char>();
+template <DataLayoutType DLType>
+void CalibComputeInt8ToFp32<DLType>::Run() {
+  auto& param = this->template Param<operators::CalibParam>();
+  const auto* din = param.input->template data<signed char>();
   std::vector<float> scale = {param.scale};
-  auto* dout = param.output->mutable_data<float>();
+  auto* dout = param.output->template mutable_data<float>();
   lite::arm::math::int8_to_fp32(
       din, dout, scale.data(), 1, 1, param.input->numel());
   return;
 }

 }  // namespace arm

@@ -48,43 +48,116 @@ void CalibComputeInt8ToFp32::Run() {
 }  // namespace lite
 }  // namespace paddle

-REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
-                     fp32_to_int8)
+REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8<
+                         DATALAYOUT(kNCHW)>,
+                     fp32_to_int8)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .Finalize();

-REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
-                     int8_to_fp32)
+REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32<
+                         DATALAYOUT(kNCHW)>,
+                     int8_to_fp32)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .Finalize();

+REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNHWC,
+                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8<
+                         DATALAYOUT(kNHWC)>,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(
+                   TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNHWC))})
+    .Finalize();

+REGISTER_LITE_KERNEL(calib, kARM, kInt8, kNHWC,
+                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32<
+                         DATALAYOUT(kNHWC)>,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(
+                   TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNHWC))})
+    .Finalize();

-REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
-                     fp32_to_int8)
+REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8<
+                         DATALAYOUT(kNCHW)>,
+                     fp32_to_int8)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .Finalize();

-REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
-                     int8_to_fp32)
+REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNCHW,
+                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32<
+                         DATALAYOUT(kNCHW)>,
+                     int8_to_fp32)
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
     .Finalize();

+REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNHWC,
+                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8<
+                         DATALAYOUT(kNHWC)>,
+                     fp32_to_int8)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(
+                   TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNHWC))})
+    .Finalize();

+REGISTER_LITE_KERNEL(calib_once, kARM, kInt8, kNHWC,
+                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32<
+                         DATALAYOUT(kNHWC)>,
+                     int8_to_fp32)
+    .BindInput("Input",
+               {LiteType::GetTensorTy(
+                   TARGET(kARM), PRECISION(kInt8), DATALAYOUT(kNHWC))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(
+                    TARGET(kARM), PRECISION(kFloat), DATALAYOUT(kNHWC))})
+    .Finalize();
lite/kernels/arm/calib_compute.h  (view file @ 08a3ed12)

@@ -21,8 +21,9 @@ namespace lite {
 namespace kernels {
 namespace arm {

+template <DataLayoutType DLType>
 class CalibComputeFp32ToInt8
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
  public:
   using param_t = operators::CalibParam;

@@ -33,8 +34,9 @@ class CalibComputeFp32ToInt8
  private:
 };

+template <DataLayoutType DLType>
 class CalibComputeInt8ToFp32
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
  public:
   using param_t = operators::CalibParam;
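Templating the calib kernels over DataLayoutType is also why the .cc now spells this->template Param<...>() and param.input->template data<float>(): once the enclosing class is a template, members reached through a dependent expression are dependent names, and a member template call needs the template disambiguator or the < parses as less-than. A minimal standalone illustration (not Paddle-Lite code):

```cpp
#include <cstdio>

// Minimal illustration of the dependent-name rule behind the
// `this->template Param<...>()` change (sketch, not Paddle-Lite code).
template <int DLType>
struct KernelLiteSketch {
  template <typename P>
  P Param() const { return P{42}; }
};

struct ParamSketch { int value; };

template <int DLType>  // stand-in for the DataLayoutType parameter
struct CalibSketch : KernelLiteSketch<DLType> {
  int Run() {
    // Without `template`, `this->Param<ParamSketch>()` fails to parse:
    // the base depends on DLType, so Param is a dependent member template.
    auto p = this->template Param<ParamSketch>();
    return p.value;
  }
};

int main() {
  std::printf("%d\n", CalibSketch<0>{}.Run());  // prints 42
  return 0;
}
```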
lite/kernels/arm/layout_compute.cc  (view file @ 08a3ed12)

@@ -20,40 +20,50 @@ namespace lite {
 namespace kernels {
 namespace arm {

-#define NCHWTONHWC(type)                                                    \
-  auto& param = this->template Param<param_t>();                            \
-  auto input = param.x->template data<type>();                              \
-  auto input_dim = param.x->dims();                                         \
-  CHECK(input_dim.size() == 4)                                              \
-      << "NCHW to NHWC should guarantee that the input dims should be 4";   \
-  int n = input_dim[0];                                                     \
-  int c = input_dim[1];                                                     \
-  int h = input_dim[2];                                                     \
-  int w = input_dim[3];                                                     \
-  param.y->Resize({n, h, w, c});                                            \
-  auto output = param.y->template mutable_data<type>(TARGET(kARM));         \
-  if (c == 1) {                                                             \
-    memcpy(output, input, sizeof(type) * n * h * w);                        \
-    return;                                                                 \
-  }                                                                         \
+#define NCHWTONHWC(type)                                                  \
+  auto& param = this->template Param<param_t>();                          \
+  auto input = param.x->template data<type>();                            \
+  auto input_dim = param.x->dims();                                       \
+  if (input_dim.size() != 4) {                                            \
+    LOG(WARNING) << "NCHW to NHWC should guarantee that the input dims "  \
+                    "should be 4, but received "                          \
+                 << input_dim.size();                                     \
+    param.y->ShareDataWith(*param.x);                                     \
+    return;                                                               \
+  }                                                                       \
+  int n = input_dim[0];                                                   \
+  int c = input_dim[1];                                                   \
+  int h = input_dim[2];                                                   \
+  int w = input_dim[3];                                                   \
+  param.y->Resize({n, h, w, c});                                          \
+  auto output = param.y->template mutable_data<type>(TARGET(kARM));       \
+  if (c == 1) {                                                           \
+    memcpy(output, input, sizeof(type) * n * h * w);                      \
+    return;                                                               \
+  }                                                                       \
   lite::arm::math::NCHW2NHWC<type>(n, c, h * w, input, output);

-#define NHWCTONCHW(type)                                                    \
-  auto& param = this->template Param<param_t>();                            \
-  auto input = param.x->template data<type>();                              \
-  auto input_dim = param.x->dims();                                         \
-  CHECK(input_dim.size() == 4)                                              \
-      << "NHWC to NCHW should guarantee that the input dims should be 4";   \
-  int n = input_dim[0];                                                     \
-  int h = input_dim[1];                                                     \
-  int w = input_dim[2];                                                     \
-  int c = input_dim[3];                                                     \
-  param.y->Resize({n, c, h, w});                                            \
-  auto output = param.y->template mutable_data<type>(TARGET(kARM));         \
-  if (c == 1) {                                                             \
-    memcpy(output, input, sizeof(type) * n * h * w);                        \
-    return;                                                                 \
-  }                                                                         \
+#define NHWCTONCHW(type)                                                  \
+  auto& param = this->template Param<param_t>();                          \
+  auto input = param.x->template data<type>();                            \
+  auto input_dim = param.x->dims();                                       \
+  if (input_dim.size() != 4) {                                            \
+    LOG(WARNING) << "NHWC to NCHW should guarantee that the input dims "  \
+                    "should be 4, but received "                          \
+                 << input_dim.size();                                     \
+    param.y->ShareDataWith(*param.x);                                     \
+    return;                                                               \
+  }                                                                       \
+  int n = input_dim[0];                                                   \
+  int h = input_dim[1];                                                   \
+  int w = input_dim[2];                                                   \
+  int c = input_dim[3];                                                   \
+  param.y->Resize({n, c, h, w});                                          \
+  auto output = param.y->template mutable_data<type>(TARGET(kARM));       \
+  if (c == 1) {                                                           \
+    memcpy(output, input, sizeof(type) * n * h * w);                      \
+    return;                                                               \
+  }                                                                       \
   lite::arm::math::NHWC2NCHW<type>(n, c, h * w, input, output);

 template <>
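For reference, the transform invoked at the end of each macro is a plain index permutation. A hedged scalar sketch of what lite::arm::math::NCHW2NHWC presumably computes over n batches, c channels, and hw = h * w spatial positions (the production routine is NEON-optimized):

```cpp
// Scalar sketch of an NCHW -> NHWC transpose (assumed semantics of
// lite::arm::math::NCHW2NHWC; illustrative, not the optimized kernel).
template <typename T>
void Nchw2NhwcSketch(int n, int c, int hw, const T* in, T* out) {
  for (int b = 0; b < n; ++b)
    for (int ch = 0; ch < c; ++ch)
      for (int s = 0; s < hw; ++s)
        out[(b * hw + s) * c + ch] = in[(b * c + ch) * hw + s];
}
```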
lite/kernels/host/feed_compute.cc  (view file @ 08a3ed12)

@@ -20,8 +20,7 @@ namespace lite {
 namespace kernels {
 namespace host {

-class FeedCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
+class FeedCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
  public:
   using param_t = operators::FeedParam;

@@ -40,7 +39,7 @@ class FeedCompute
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    feed, kHost, kAny, kNCHW, paddle::lite::kernels::host::FeedCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
     .Finalize();
lite/kernels/host/fetch_compute.cc  (view file @ 08a3ed12)

@@ -20,8 +20,7 @@ namespace lite {
 namespace kernels {
 namespace host {

-class FetchCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
+class FetchCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
  public:
   using param_t = operators::FeedParam;

@@ -43,11 +42,7 @@ class FetchCompute
 }  // namespace paddle

 REGISTER_LITE_KERNEL(
-    fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def)
-    .BindInput("X",
-               {LiteType::GetTensorTy(
-                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(
-                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
+    fetch, kHost, kAny, kNCHW, paddle::lite::kernels::host::FetchCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
     .Finalize();