PaddlePaddle / Paddle-Lite

Commit 9c6b00a2
Authored Mar 18, 2020 by jackzhang235

    support first conv, test=develop

Parent: 273e6fae

9 changed files with 218 additions and 24 deletions (+218 −24)
lite/api/cxx_api_impl.cc               +8   −3
lite/api/paddle_api.h                  +3   −0
lite/core/device_info.cc               +17  −1
lite/core/device_info.h                +11  −1
lite/core/mir/mlu_postprocess_pass.cc  +52  −2
lite/core/mir/mlu_postprocess_pass.h   +10  −0
lite/kernels/mlu/bridges/conv_op.cc    +68  −17
lite/kernels/mlu/bridges/graph.h       +47  −0
lite/kernels/mlu/bridges/tensor.h      +2   −0
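Taken together, the changes fold per-channel image normalization (subtract mean, divide by std) into the first convolution of an MLU subgraph, so the compiled graph can consume raw uint8 image input directly. A minimal usage sketch follows; it assumes CxxConfig gained setters matching the getters this commit reads in CxxPaddleApiImpl::Init() — set_use_first_conv(), set_mean(), and set_std() are assumptions not shown in this diff, as are the model path and the mean/std values:

#include <vector>
#include "lite/api/paddle_api.h"

int main() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("mobilenet_v1");  // hypothetical model path
  // Assumed setters mirroring the getters consumed in Init():
  config.set_use_first_conv(true);            // fold normalization into conv1
  config.set_mean({124.0f, 117.0f, 104.0f});  // per-channel mean, uint8 scale
  config.set_std({58.8f, 57.1f, 57.4f});      // per-channel std, uint8 scale
  auto predictor = paddle::lite_api::CreatePaddlePredictor<
      paddle::lite_api::CxxConfig>(config);
  predictor->Run();
  return 0;
}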
lite/api/cxx_api_impl.cc

@@ -38,6 +38,14 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   Env<TARGET(kMLU)>::Init();
   mlu_core_version_ = config.mlu_core_version();
   mlu_core_number_ = config.mlu_core_number();
+  use_first_conv_ = config.use_first_conv();
+  mean_vec_ = config.mean();
+  std_vec_ = config.std();
+  lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_,
+                                           mlu_core_number_,
+                                           use_first_conv_,
+                                           mean_vec_,
+                                           std_vec_);
 #endif  // LITE_WITH_MLU
   auto places = config.valid_places();
   std::vector<std::string> passes{};

@@ -87,9 +95,6 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
 void CxxPaddleApiImpl::Run() {
 #ifdef LITE_WITH_ARM
   lite::DeviceInfo::Global().SetRunMode(mode_, threads_);
-#endif
-#ifdef LITE_WITH_MLU
-  lite::DeviceInfo::Global().SetMLURunMode(mlu_core_version_, mlu_core_number_);
 #endif
   raw_predictor_.Run();
 }
lite/api/paddle_api.h

@@ -108,6 +108,9 @@ class LITE_API PaddlePredictor {
   lite_api::PowerMode mode_{lite_api::LITE_POWER_NO_BIND};
   lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLU_270};
   int mlu_core_number_{1};
+  bool use_first_conv_{false};
+  std::vector<float> mean_vec_;
+  std::vector<float> std_vec_;
 };

 /// Base class for all the configs.
lite/core/device_info.cc

@@ -69,6 +69,9 @@ thread_local int64_t DeviceInfo::count_ = 0;
 #ifdef LITE_WITH_MLU
 thread_local cnmlCoreVersion_t DeviceInfo::mlu_core_version_{CNML_MLU270};
 thread_local int DeviceInfo::mlu_core_number_{1};
+thread_local bool DeviceInfo::use_first_conv_{false};
+thread_local std::vector<float> DeviceInfo::mean_vec_;
+thread_local std::vector<float> DeviceInfo::std_vec_;
 #endif

 #ifdef TARGET_IOS

@@ -1087,7 +1090,10 @@ int DeviceInfo::Setup() {
 #ifdef LITE_WITH_MLU
 void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
-                               int core_number) {
+                               int core_number,
+                               bool use_first_conv,
+                               const std::vector<float>& mean_vec,
+                               const std::vector<float>& std_vec) {
   switch (core_version) {
     case (lite_api::MLUCoreVersion::MLU_220):
       mlu_core_version_ = CNML_MLU220;

@@ -1100,11 +1106,21 @@ void DeviceInfo::SetMLURunMode(lite_api::MLUCoreVersion core_version,
       break;
   }
   mlu_core_number_ = core_number;
+  use_first_conv_ = use_first_conv;
+  mean_vec_ = mean_vec;
+  std_vec_ = std_vec;
 }

 cnmlCoreVersion_t DeviceInfo::MLUCoreVersion() { return mlu_core_version_; }

 int DeviceInfo::MLUCoreNumber() { return mlu_core_number_; }

+bool DeviceInfo::UseFirstConv() { return use_first_conv_; }
+
+const std::vector<float>& DeviceInfo::MeanVec() const { return mean_vec_; }
+
+const std::vector<float>& DeviceInfo::StdVec() const { return std_vec_; }
+
 #endif  // LITE_WITH_MLU

 void DeviceInfo::SetRunMode(lite_api::PowerMode mode, int thread_num) {
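Worth noting: all of these fields are thread_local, and with this commit SetMLURunMode() is called only from CxxPaddleApiImpl::Init(), so the configured values are visible only on the thread that created the predictor. A self-contained sketch of that semantics (stand-in type, not Paddle-Lite code):

#include <iostream>
#include <thread>

struct Settings {
  static thread_local bool use_first_conv_;
};
thread_local bool Settings::use_first_conv_{false};

int main() {
  Settings::use_first_conv_ = true;  // set on this thread, as Init() does
  std::thread other([] {
    // Each thread has its own copy; this one still holds the default.
    std::cout << Settings::use_first_conv_ << std::endl;  // prints 0
  });
  other.join();
  std::cout << Settings::use_first_conv_ << std::endl;  // prints 1
  return 0;
}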
lite/core/device_info.h

@@ -56,9 +56,16 @@ class DeviceInfo {
   void SetRunMode(lite_api::PowerMode mode, int thread_num);
 #ifdef LITE_WITH_MLU
-  void SetMLURunMode(lite_api::MLUCoreVersion core_version, int core_number);
+  void SetMLURunMode(lite_api::MLUCoreVersion core_version,
+                     int core_number,
+                     bool use_first_conv,
+                     const std::vector<float>& mean_vec,
+                     const std::vector<float>& std_vec);
   cnmlCoreVersion_t MLUCoreVersion();
   int MLUCoreNumber();
+  bool UseFirstConv();
+  const std::vector<float>& MeanVec() const;
+  const std::vector<float>& StdVec() const;
 #endif
   void SetCache(int l1size, int l2size, int l3size);
   void SetArch(ARMArch arch) { arch_ = arch; }

@@ -114,6 +121,9 @@ class DeviceInfo {
 #ifdef LITE_WITH_MLU
   static thread_local cnmlCoreVersion_t mlu_core_version_;
   static thread_local int mlu_core_number_;
+  static thread_local bool use_first_conv_;
+  static thread_local std::vector<float> mean_vec_;
+  static thread_local std::vector<float> std_vec_;
 #endif
   void SetDotInfo(int argc, ...);
lite/core/mir/mlu_postprocess_pass.cc

@@ -15,7 +15,6 @@
 #include "lite/core/mir/mlu_postprocess_pass.h"
 #include <list>
 #include <memory>
-#include <set>
 #include <string>
 #include <utility>
 #include <vector>

@@ -211,6 +210,10 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
   auto* cur_node = head_node;
   const auto name_prefix =
       head_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
+  bool is_first_conv_head =
+      std::find(first_conv_nodes_.begin(),
+                first_conv_nodes_.end(),
+                head_node->AsArg().name) != first_conv_nodes_.end();

   // layout cast node
   if (head_type->layout() != inst_type->layout()) {

@@ -225,7 +228,7 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
   }

   // precision cast node
-  if (head_type->precision() != inst_type->precision()) {
+  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
     cur_node = InsertCastBefore(
         "cast",
         name_prefix + "cast",

@@ -433,6 +436,49 @@ void MLUPostprocessPass::RecreateOp(Node* inst_node, SSAGraph* graph) {
   }
 }

+bool MLUPostprocessPass::IsFirstConvInSubgraph(Node* arg_node, Node* inst) {
+  auto* block_desc =
+      static_cast<operators::SubgraphOp*>(inst->AsStmt().op().get())
+          ->GetSubBlock();
+  for (int op_idx = 0; op_idx < block_desc->OpsSize(); op_idx++) {
+    auto op_desc = block_desc->GetOp<cpp::OpDesc>(op_idx);
+    CHECK(op_desc);
+    if (op_desc->Type() == "conv2d") {
+      for (auto& names : op_desc->inputs()) {
+        if (std::find(names.second.begin(),
+                      names.second.end(),
+                      arg_node->AsArg().name) != names.second.end()) {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+bool MLUPostprocessPass::IsFirstConvNode(Node* arg_node) {
+  CHECK(arg_node->IsArg());
+  for (auto& inst : arg_node->outlinks) {
+    if (inst->AsStmt().op_type() == "subgraph") {
+      return IsFirstConvInSubgraph(arg_node, inst);
+    }
+  }
+  return false;
+}
+
+void MLUPostprocessPass::GatherFirstConvNodes(SSAGraph* graph) {
+  for (auto& node : graph->mutable_nodes()) {
+    if (!node.IsStmt()) continue;
+    if (node.AsStmt().op_type() == "feed") {
+      for (auto& out : node.outlinks) {
+        if (IsFirstConvNode(out)) {
+          first_conv_nodes_.insert(out->AsArg().name);
+        }
+      }
+    }
+  }
+}
+
 void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   for (auto& node : graph->mutable_nodes()) {
     if (!node.IsStmt()) continue;

@@ -487,6 +533,10 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // Thus here we change these args' layout to NHWC
   ModifyLayout(graph.get());

+  if (lite::DeviceInfo::Global().UseFirstConv()) {
+    GatherFirstConvNodes(graph.get());
+  }
+
   // insert io_copy, layout and precision cast of subgraph's inputs and outputs
   for (auto& node : graph->mutable_nodes()) {
     if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
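The new pass logic, in short: Apply() first records every graph input that a feed op produces and that a conv2d inside an MLU subgraph consumes, and InsertBefore() then skips the precision-cast node for those heads, leaving the raw input for the fused first conv. A self-contained illustration of the detection rule, using simplified stand-ins for the real OpDesc type:

#include <algorithm>
#include <map>
#include <string>
#include <vector>

struct OpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
};

// An input is a "first conv" head iff a conv2d inside the subgraph lists
// it among its input argument names.
bool IsFirstConvInput(const std::vector<OpDesc>& subgraph_ops,
                      const std::string& arg_name) {
  for (const auto& op : subgraph_ops) {
    if (op.type != "conv2d") continue;
    for (const auto& names : op.inputs) {
      if (std::find(names.second.begin(), names.second.end(), arg_name) !=
          names.second.end()) {
        return true;
      }
    }
  }
  return false;
}

int main() {
  std::vector<OpDesc> ops{{"conv2d", {{"Input", {"image"}}}}};
  return IsFirstConvInput(ops, "image") ? 0 : 1;  // returns 0: detected
}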
lite/core/mir/mlu_postprocess_pass.h

@@ -15,6 +15,7 @@
 #pragma once
 #include <memory>
+#include <set>
 #include <string>
 #include <vector>
 #include "lite/core/mir/pass.h"

@@ -107,6 +108,15 @@ class MLUPostprocessPass : public ProgramPass {
                        const Type* cast_type);

   void RecreateOp(Node* inst_node, SSAGraph* graph);

+  void GatherFirstConvNodes(SSAGraph* graph);
+
+  bool IsFirstConvNode(Node* arg_node);
+
+  bool IsFirstConvInSubgraph(Node* arg_node, Node* inst);
+
+ private:
+  std::set<std::string> first_conv_nodes_;
 };

 }  // namespace mir
lite/kernels/mlu/bridges/conv_op.cc

@@ -119,14 +119,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     LOG(FATAL) << "UnSupported weight precision!";
   }
-  cnmlConvOpParam_t conv_param;
-  CNML_CALL(cnmlCreateConvOpParam(&conv_param,
-                                  strides[0],
-                                  strides[1],
-                                  dilations[0],
-                                  dilations[1],
-                                  paddings[0] * 2,
-                                  paddings[2] * 2));
   std::string bias_var_name;
   std::shared_ptr<MLUTensor> bias_tensor;
   if (HasInputArg(op_info, scope, "Bias")) {

@@ -160,15 +152,75 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                            graph->FPType());
     graph->BindConstData(bias_var_name, bias);
   }
-  cnmlBaseOp_t conv_op;
   const auto input_scale = op_info->GetAttr<float>("input_scale");
-  CNML_CALL(cnmlCreateConvOpForward(
-      &conv_op,
-      conv_param,
-      graph->GetNode(input_var_name)->mlu_tensor(),
-      output_tensor->mlu_tensor(),
-      filter_tensor->mlu_tensor(),
-      bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
+
+  bool use_first_conv = false;
+  if (lite::DeviceInfo::Global().UseFirstConv() && input_dims_nhwc[3] == 3) {
+    use_first_conv = true;
+  }
+
+  cnmlBaseOp_t conv_op;
+  if (use_first_conv) {
+    cnmlConvFirstOpParam_t conv_param;
+    CNML_CALL(cnmlCreateConvFirstOpParam_V2(&conv_param,
+                                            strides[0],
+                                            strides[1],
+                                            dilations[0],
+                                            dilations[1],
+                                            paddings[2],
+                                            paddings[2],
+                                            paddings[0],
+                                            paddings[0]));
+    const auto mean_tensor = graph->AddNode("first_conv_mean_tensor",
+                                            std::vector<int64_t>{3},
+                                            CNML_CONST,
+                                            CNML_CNHW,
+                                            graph->FPType());
+    const auto std_tensor = graph->AddNode("first_conv_std_tensor",
+                                           std::vector<int64_t>{3},
+                                           CNML_CONST,
+                                           CNML_CNHW,
+                                           graph->FPType());
+    graph->BindConstRawData("first_conv_mean_tensor",
+                            lite::DeviceInfo::Global().MeanVec().data(),
+                            3,
+                            false);
+    graph->BindConstRawData("first_conv_std_tensor",
+                            lite::DeviceInfo::Global().StdVec().data(),
+                            3,
+                            false);
+
+    graph->GetNode(input_var_name)->set_mlu_dtype(CNML_DATA_UINT8);
+    CNML_CALL(cnmlCreateConvFirstOpForward(
+        &conv_op,
+        conv_param,
+        graph->GetNode(input_var_name)->mlu_tensor(),
+        mean_tensor->mlu_tensor(),
+        output_tensor->mlu_tensor(),
+        filter_tensor->mlu_tensor(),
+        bias_tensor ? bias_tensor->mlu_tensor() : nullptr,
+        std_tensor->mlu_tensor()));
+    CNML_CALL(cnmlDestroyConvFirstOpParam(&conv_param));
+  } else {
+    cnmlConvOpParam_t conv_param;
+    CNML_CALL(cnmlCreateConvOpParam(&conv_param,
+                                    strides[0],
+                                    strides[1],
+                                    dilations[0],
+                                    dilations[1],
+                                    paddings[0] * 2,
+                                    paddings[2] * 2));
+    CNML_CALL(cnmlCreateConvOpForward(
+        &conv_op,
+        conv_param,
+        graph->GetNode(input_var_name)->mlu_tensor(),
+        output_tensor->mlu_tensor(),
+        filter_tensor->mlu_tensor(),
+        bias_tensor ? bias_tensor->mlu_tensor() : nullptr));
+    CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
+  }
   graph->SetComputingDataType(
       conv_op, graph->GetNode(input_var_name)->mlu_tensor(), 1 / input_scale);

@@ -183,7 +235,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   }
   graph->BindConstData(filter_var_name, filter);

   graph->FuseOp(conv_op);
-  CNML_CALL(cnmlDestroyConvOpParam(&conv_param));
   return REBUILD_WHEN_SHAPE_CHANGED;
 }
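The mean/std tensors bound above suggest the fused op's semantics: given a uint8 NHWC input x, the kernel effectively computes conv over (x - mean[c]) / std[c] in place of a separate normalization step. That reading is inferred from this diff, not from cnml documentation. A plain reference sketch of the per-channel normalization being folded in:

#include <cstdint>
#include <vector>

// Reference-only: normalize a contiguous NHWC uint8 buffer with C == 3,
// matching the input_dims_nhwc[3] == 3 guard in the converter above.
std::vector<float> NormalizeNHWC(const std::vector<uint8_t>& x,
                                 const float mean[3],
                                 const float std_dev[3]) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    const int c = i % 3;  // NHWC: channel is the fastest-varying index
    y[i] = (static_cast<float>(x[i]) - mean[c]) / std_dev[c];
  }
  return y;
}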
lite/kernels/mlu/bridges/graph.h

@@ -180,6 +180,53 @@ class Graph {
     }
   }

+  template <typename T>
+  void* RegisterConstData(size_t len) {
+    void* addr = malloc(len * sizeof(T));
+    const_data_storage_.push_back(addr);
+    return addr;
+  }
+
+  void FreeConstData() {
+    for (auto& addr : const_data_storage_) {
+      free(addr);
+    }
+  }
+
+  void BindConstRawData(std::string tensor_name,
+                        const float* data,
+                        size_t len,
+                        bool alloc = true) {
+    void* alloc_data;
+    if (fp_type_ == CNML_DATA_FLOAT32) {
+      if (alloc) {
+        alloc_data = RegisterConstData<float>(len);
+        memcpy(alloc_data, data, len * sizeof(float));
+      } else {
+        alloc_data = const_cast<void*>(static_cast<const void*>(data));
+      }
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), alloc_data, false));
+    } else if (fp_type_ == CNML_DATA_FLOAT16) {
+      void* data_fp16 = RegisterConstData<::paddle::lite::fluid::float16>(len);
+      // for (size_t i = 0; i < len; ++i) {
+      //   static_cast<::paddle::lite::fluid::float16*>(data_fp16)[i] =
+      //       static_cast<::paddle::lite::fluid::float16>(data[i]);
+      // }
+      CNRT_CALL(
+          cnrtCastDataType(const_cast<void*>(static_cast<const void*>(data)),
+                           CNRT_FLOAT32,
+                           data_fp16,
+                           CNRT_FLOAT16,
+                           len,
+                           nullptr));
+      CNML_CALL(cnmlBindConstData_V2(
+          nodes_[tensor_name]->mlu_tensor(), data_fp16, false));
+    } else {
+      CHECK(0);
+    }
+  }
+
   void BindConstData(std::string tensor_name, ::paddle::lite::Tensor* tensor) {
     const float* data = tensor->data<float>();
     size_t len = tensor->data_size();
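A note on ownership in these helpers: with alloc = true, BindConstRawData copies the data into graph-owned storage (registered via RegisterConstData and released by FreeConstData); with alloc = false, as conv_op.cc uses for the mean/std vectors, the caller's buffer is bound directly and must outlive the compiled graph. A minimal stand-alone model of that rule (names reused for illustration only, not the real Graph class):

#include <cstdlib>
#include <cstring>
#include <vector>

class ConstStorage {
 public:
  void* RegisterConstData(size_t bytes) {
    void* addr = std::malloc(bytes);
    storage_.push_back(addr);
    return addr;
  }
  // alloc = true: copy into storage we own; alloc = false: borrow `data`,
  // so the caller must keep it alive while the binding is in use.
  const void* Bind(const float* data, size_t len, bool alloc) {
    if (!alloc) return data;
    void* copy = RegisterConstData(len * sizeof(float));
    std::memcpy(copy, data, len * sizeof(float));
    return copy;
  }
  ~ConstStorage() {
    for (void* addr : storage_) std::free(addr);  // mirrors FreeConstData()
  }

 private:
  std::vector<void*> storage_;
};

int main() {
  std::vector<float> mean{124.f, 117.f, 104.f};
  ConstStorage graph;
  const void* bound = graph.Bind(mean.data(), mean.size(), /*alloc=*/false);
  (void)bound;  // valid only while `mean` is alive
  return 0;
}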
lite/kernels/mlu/bridges/tensor.h

@@ -47,6 +47,8 @@ class MLUTensor {
     return mlu_ptr_;
   }

+  void set_mlu_dtype(cnmlDataType_t type) { mlu_dtype_ = type; }
+
   ~MLUTensor();

  private: