Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f77de54a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f77de54a
编写于
4月 28, 2020
作者:
D
dinghao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix tensor dirty
上级
9c1a5db4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
14 addition
and
8 deletion
+14
-8
mindspore/ccsrc/ir/meta_tensor.cc
mindspore/ccsrc/ir/meta_tensor.cc
+3
-1
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+2
-1
tests/st/ops/gpu/test_assign_add_op.py
tests/st/ops/gpu/test_assign_add_op.py
+9
-6
未找到文件。
mindspore/ccsrc/ir/meta_tensor.cc
浏览文件 @
f77de54a
...
...
@@ -164,8 +164,9 @@ Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::arr
Tensor
::
Tensor
(
const
py
::
int_
&
input
,
const
TypePtr
&
data_type
)
{
init
(
py
::
array
(
input
),
data_type
);
}
Tensor
::
Tensor
(
const
Tensor
&
tensor
,
const
TypePtr
&
data_type
)
:
MetaTensor
(
tensor
),
d
irty_
(
tensor
.
dirty_
),
d
evice_address_
(
tensor
.
device_address_
)
{
:
MetaTensor
(
tensor
),
device_address_
(
tensor
.
device_address_
)
{
init
(
tensor
.
data_
,
data_type
);
dirty_
=
tensor
.
is_dirty
();
}
Tensor
&
Tensor
::
operator
=
(
const
Tensor
&
tensor
)
{
...
...
@@ -291,6 +292,7 @@ void Tensor::init(const py::array &input, const TypeId &data_type) {
}
else
{
data_
=
input
;
}
dirty_
=
true
;
}
void
Tensor
::
init
(
TypeId
data_type
,
const
std
::
vector
<
int
>
&
shape
,
py
::
array
*
const
data
)
{
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
f77de54a
...
...
@@ -127,6 +127,7 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
enable_pynative_infer
())
{
tensor
->
set_device_address
(
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
));
tensor
->
set_dirty
(
false
);
}
else
if
(
!
address
->
SyncDeviceToHost
(
trans
::
GetRuntimePaddingShape
(
node
,
output_index
),
LongToSize
(
tensor
->
data
().
nbytes
()),
tensor
->
data_type
(),
tensor
->
data_c
(
true
)))
{
...
...
@@ -491,7 +492,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
need_sync
=
true
;
}
}
else
{
if
(
tensor
->
is_dirty
()
||
!
AnfAlgo
::
IsParameterWeight
(
pk_node
)
)
{
if
(
tensor
->
is_dirty
())
{
need_sync
=
true
;
}
else
if
(
tensor
->
device_address
()
!=
device_address
)
{
(
void
)
tensor
->
data_sync
();
...
...
tests/st/ops/gpu/test_assign_add_op.py
浏览文件 @
f77de54a
...
...
@@ -51,19 +51,22 @@ def test_assign_add():
[[
54
,
57
,
60
],
[
63
,
66
,
69
],
[
72
,
75
,
78
]]]])
x
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
y
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
x1
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
y1
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
x2
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
y2
=
Tensor
(
np
.
arange
(
1
*
3
*
3
*
3
).
reshape
(
1
,
3
,
3
,
3
).
astype
(
np
.
float32
))
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
'GPU'
)
add
=
AssignAdd
()
output1
=
add
(
x
,
y
)
output1
=
add
(
x
1
,
y1
)
assert
(
output1
.
asnumpy
()
==
expect1
).
all
()
output2
=
add
(
output1
,
y
)
output2
=
add
(
output1
,
y
1
)
assert
(
output2
.
asnumpy
()
==
expect2
).
all
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
add
=
AssignAdd
()
output1
=
add
(
x
,
y
)
output1
=
add
(
x
2
,
y2
)
assert
(
output1
.
asnumpy
()
==
expect1
).
all
()
output2
=
add
(
output1
,
y
)
output2
=
add
(
output1
,
y
2
)
assert
(
output2
.
asnumpy
()
==
expect2
).
all
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录