Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7304f024
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7304f024
编写于
7月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2902 move weight data copy to warmup stage
Merge pull request !2902 from dinghao/master
上级
dab7ad44
bd5a9edd
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
28 addition
and
12 deletion
+28
-12
mindspore/ccsrc/session/ascend_inference_session.cc
mindspore/ccsrc/session/ascend_inference_session.cc
+27
-12
mindspore/ccsrc/session/ascend_inference_session.h
mindspore/ccsrc/session/ascend_inference_session.h
+1
-0
未找到文件。
mindspore/ccsrc/session/ascend_inference_session.cc
浏览文件 @
7304f024
...
...
@@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy;
namespace
mindspore
{
namespace
session
{
namespace
{
std
::
set
<
AnfNodePtr
>
weight_infos
;
static
TypeId
GetDataType
(
const
py
::
buffer_info
&
buf
)
{
if
(
buf
.
format
.
size
()
==
1
)
{
switch
(
buf
.
format
.
front
())
{
...
...
@@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
MS_EXCEPTION_IF_NULL
(
pk_node
);
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
pk_node
,
0
);
MS_EXCEPTION_IF_NULL
(
device_address
);
if
(
AnfAlgo
::
IsParameterWeight
(
pk_node
))
{
if
(
weight_infos
.
count
(
pk_node
)
!=
0
)
{
continue
;
if
(
!
AnfAlgo
::
IsParameterWeight
(
pk_node
))
{
tensor
=
inputs
[
no_weight_input
++
];
if
(
!
device_address
->
SyncHostToDevice
(
trans
::
GetRuntimePaddingShape
(
pk_node
,
0
),
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
()))
{
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
}
}
}
}
GraphId
AscendInferenceSession
::
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
{
auto
graph_id
=
AscendSession
::
CompileGraph
(
func_graph
);
auto
kernel_graph
=
GetGraph
(
graph_id
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// load weight data to device
auto
input_nodes
=
kernel_graph
->
inputs
();
for
(
size_t
i
=
0
;
i
<
input_nodes
.
size
();
++
i
)
{
if
(
!
input_nodes
[
i
]
->
isa
<
Parameter
>
())
{
MS_LOG
(
ERROR
)
<<
"Kernel graph inputs have anfnode which is not Parameter"
;
continue
;
}
auto
pk_node
=
input_nodes
[
i
]
->
cast
<
ParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
pk_node
);
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
pk_node
,
0
);
MS_EXCEPTION_IF_NULL
(
device_address
);
if
(
AnfAlgo
::
IsParameterWeight
(
pk_node
))
{
auto
param_value
=
std
::
dynamic_pointer_cast
<
ParamValuePy
>
(
pk_node
->
default_param
());
MS_EXCEPTION_IF_NULL
(
param_value
);
auto
py_param
=
param_value
->
value
();
...
...
@@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
LongToSize
(
buf
.
size
*
buf
.
itemsize
),
buf_type
,
buf
.
ptr
))
{
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
}
weight_infos
.
insert
(
pk_node
);
}
else
{
tensor
=
inputs
[
no_weight_input
++
];
if
(
!
device_address
->
SyncHostToDevice
(
trans
::
GetRuntimePaddingShape
(
pk_node
,
0
),
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
()))
{
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
}
}
}
return
graph_id
;
}
}
// namespace session
}
// namespace mindspore
mindspore/ccsrc/session/ascend_inference_session.h
浏览文件 @
7304f024
...
...
@@ -38,6 +38,7 @@ class AscendInferenceSession : public AscendSession {
~
AscendInferenceSession
()
=
default
;
void
LoadInputData
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs_const
)
const
;
GraphId
CompileGraph
(
NotNull
<
FuncGraphPtr
>
func_graph
)
override
;
};
MS_REG_SESSION
(
kDavinciInferenceDevice
,
AscendInferenceSession
);
}
// namespace session
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录