Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
1cb8d9da
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1cb8d9da
编写于
7月 31, 2020
作者:
C
chujinjin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize updateoutput in gpu
上级
6eddd65c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
20 addition
and
4 deletion
+20
-4
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+6
-1
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+12
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+2
-2
未找到文件。
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
1cb8d9da
...
@@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
...
@@ -426,7 +426,12 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
if
(
op_run_info
.
value
!=
nullptr
)
{
if
(
op_run_info
.
value
!=
nullptr
)
{
std
::
vector
<
tensor
::
TensorPtr
>
pre_output_tensors
;
std
::
vector
<
tensor
::
TensorPtr
>
pre_output_tensors
;
TensorValueToTensor
(
op_run_info
.
value
,
&
pre_output_tensors
);
TensorValueToTensor
(
op_run_info
.
value
,
&
pre_output_tensors
);
std
::
copy
(
pre_output_tensors
.
begin
(),
pre_output_tensors
.
end
(),
std
::
back_inserter
(
outputs
));
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_dirty
(
false
);
outputs
.
emplace_back
(
tensor
);
}
}
else
{
}
else
{
UpdateOutputs
(
graph
,
&
outputs
,
input_tensors
);
UpdateOutputs
(
graph
,
&
outputs
,
input_tensors
);
}
}
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
1cb8d9da
...
@@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
...
@@ -300,7 +300,18 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
}
}
// Fetch outputs
// Fetch outputs
VectorRef
outputs
;
VectorRef
outputs
;
UpdateOutputs
(
kernel_graph
,
&
outputs
,
input_tensors
);
if
(
op_run_info
.
value
!=
nullptr
)
{
std
::
vector
<
tensor
::
TensorPtr
>
pre_output_tensors
;
TensorValueToTensor
(
op_run_info
.
value
,
&
pre_output_tensors
);
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_dirty
(
false
);
outputs
.
emplace_back
(
tensor
);
}
}
else
{
UpdateOutputs
(
kernel_graph
,
&
outputs
,
input_tensors
);
}
// Trans output to tuple
// Trans output to tuple
auto
output_tensors
=
TransformBaseRefListToTuple
(
outputs
);
auto
output_tensors
=
TransformBaseRefListToTuple
(
outputs
);
if
(
!
utils
::
isa
<
PyObjectRef
>
(
output_tensors
)
||
if
(
!
utils
::
isa
<
PyObjectRef
>
(
output_tensors
)
||
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
1cb8d9da
...
@@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
...
@@ -565,9 +565,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if
(
session
==
nullptr
)
{
if
(
session
==
nullptr
)
{
session
=
session
::
SessionFactory
::
Get
().
Create
(
device_target
);
session
=
session
::
SessionFactory
::
Get
().
Create
(
device_target
);
MS_EXCEPTION_IF_NULL
(
session
);
session
->
Init
(
ms_context
->
device_id
());
}
}
MS_EXCEPTION_IF_NULL
(
session
);
session
->
Init
(
ms_context
->
device_id
());
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
std
::
vector
<
int
>
tensors_mask
;
std
::
vector
<
int
>
tensors_mask
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录