Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f42b3bbf
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f42b3bbf
编写于
9月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5598 add tensor sync status
Merge pull request !5598 from kisnwang/async-run-graph
上级
529e1a0a
5614b2ba
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
37 addition
and
36 deletion
+37
-36
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+1
-1
mindspore/ccsrc/backend/session/executor.cc
mindspore/ccsrc/backend/session/executor.cc
+2
-2
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+3
-3
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+10
-9
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+1
-1
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+5
-4
mindspore/ccsrc/runtime/device/kernel_adjust.cc
mindspore/ccsrc/runtime/device/kernel_adjust.cc
+2
-2
mindspore/core/ir/tensor.cc
mindspore/core/ir/tensor.cc
+3
-6
mindspore/core/ir/tensor.h
mindspore/core/ir/tensor.h
+10
-8
未找到文件。
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
f42b3bbf
...
@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
...
@@ -410,7 +410,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
outputs
->
emplace_back
(
tensor
);
outputs
->
emplace_back
(
tensor
);
}
}
}
else
{
}
else
{
...
...
mindspore/ccsrc/backend/session/executor.cc
浏览文件 @
f42b3bbf
...
@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs,
...
@@ -38,9 +38,9 @@ void UpdateOutputTensors(VectorRef *outputs,
auto
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
);
auto
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
output_index
);
tensor
->
set_device_address
(
address
);
tensor
->
set_device_address
(
address
);
}
}
if
(
tensor
->
need_sync
())
{
if
(
tensor
->
NeedSyncDeviceToHostImmediately
())
{
tensor
->
data_sync
();
tensor
->
data_sync
();
tensor
->
set_
need_sync
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
}
}
}
}
}
}
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
f42b3bbf
...
@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
...
@@ -158,7 +158,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if
(
tensor_address
==
nullptr
||
tensor_address
!=
device_address
)
{
if
(
tensor_address
==
nullptr
||
tensor_address
!=
device_address
)
{
need_sync
=
true
;
need_sync
=
true
;
}
}
}
else
if
(
tensor
->
is_dirty
()
||
tensor_address
==
nullptr
)
{
}
else
if
(
tensor
->
NeedSyncHostToDevice
()
||
tensor_address
==
nullptr
)
{
need_sync
=
true
;
need_sync
=
true
;
}
else
if
(
tensor_address
!=
device_address
)
{
}
else
if
(
tensor_address
!=
device_address
)
{
if
(
tensor_address
->
DeviceType
()
==
device_address
->
DeviceType
())
{
if
(
tensor_address
->
DeviceType
()
==
device_address
->
DeviceType
())
{
...
@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
...
@@ -177,7 +177,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
}
}
}
}
}
}
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
}
}
}
}
...
@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
...
@@ -332,7 +332,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
for
(
auto
&
pre_output
:
pre_output_tensors
)
{
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
pre_output
->
data_type
(),
pre_output
->
shape
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_device_address
(
pre_output
->
device_address
());
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
outputs
->
emplace_back
(
tensor
);
outputs
->
emplace_back
(
tensor
);
}
}
}
else
{
}
else
{
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
f42b3bbf
...
@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
...
@@ -75,7 +75,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
temp_shape
.
emplace_back
(
1
);
temp_shape
.
emplace_back
(
1
);
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
tensor
->
set_padding_type
(
AnfAlgo
::
GetOutputReshapeType
(
node
,
output_index
));
tensor
->
set_padding_type
(
AnfAlgo
::
GetOutputReshapeType
(
node
,
output_index
));
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
tensor
->
SetNeedWait
(
true
);
tensor
->
SetNeedWait
(
true
);
return
tensor
;
return
tensor
;
}
}
...
@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
...
@@ -96,12 +96,13 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
MS_EXCEPTION_IF_NULL
(
ms_context
);
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
&&
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
&&
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
!=
kGPUDevice
)
{
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
!=
kGPUDevice
)
{
tensor
->
set_need_sync
(
true
);
tensor
->
set_sync_status
(
kNeedSyncDeviceToHostImmediately
);
}
else
{
tensor
->
set_sync_status
(
kNeedSyncDeviceToHost
);
}
}
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
tensor
->
SetNeedWait
(
true
);
tensor
->
SetNeedWait
(
true
);
}
}
tensor
->
set_dirty
(
false
);
return
tensor
;
return
tensor
;
}
}
...
@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
...
@@ -198,7 +199,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto
*
cur_val
=
static_cast
<
int32_t
*>
(
cur_loop_tensor
->
data_c
());
auto
*
cur_val
=
static_cast
<
int32_t
*>
(
cur_loop_tensor
->
data_c
());
MS_EXCEPTION_IF_NULL
(
cur_val
);
MS_EXCEPTION_IF_NULL
(
cur_val
);
*
cur_val
=
0
;
*
cur_val
=
0
;
cur_loop_tensor
->
set_
dirty
(
tru
e
);
cur_loop_tensor
->
set_
sync_status
(
kNeedSyncHostToDevic
e
);
// set loop_count to zero
// set loop_count to zero
MS_EXCEPTION_IF_NULL
(
inputs
);
MS_EXCEPTION_IF_NULL
(
inputs
);
inputs
->
push_back
(
cur_loop_tensor
);
inputs
->
push_back
(
cur_loop_tensor
);
...
@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
...
@@ -209,7 +210,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto
*
next_val
=
static_cast
<
int32_t
*>
(
next_loop_tensor
->
data_c
());
auto
*
next_val
=
static_cast
<
int32_t
*>
(
next_loop_tensor
->
data_c
());
MS_EXCEPTION_IF_NULL
(
next_val
);
MS_EXCEPTION_IF_NULL
(
next_val
);
*
next_val
=
0
;
*
next_val
=
0
;
next_loop_tensor
->
set_
dirty
(
tru
e
);
next_loop_tensor
->
set_
sync_status
(
kNeedSyncHostToDevic
e
);
// set loop_count to zero
// set loop_count to zero
MS_EXCEPTION_IF_NULL
(
inputs
);
MS_EXCEPTION_IF_NULL
(
inputs
);
inputs
->
push_back
(
next_loop_tensor
);
inputs
->
push_back
(
next_loop_tensor
);
...
@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
...
@@ -219,7 +220,7 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vecto
auto
*
epoch_val
=
static_cast
<
int32_t
*>
(
epoch_tensor
->
data_c
());
auto
*
epoch_val
=
static_cast
<
int32_t
*>
(
epoch_tensor
->
data_c
());
MS_EXCEPTION_IF_NULL
(
epoch_val
);
MS_EXCEPTION_IF_NULL
(
epoch_val
);
*
epoch_val
=
graph
->
current_epoch
();
*
epoch_val
=
graph
->
current_epoch
();
epoch_tensor
->
set_
dirty
(
tru
e
);
epoch_tensor
->
set_
sync_status
(
kNeedSyncHostToDevic
e
);
inputs
->
push_back
(
epoch_tensor
);
inputs
->
push_back
(
epoch_tensor
);
MS_LOG
(
INFO
)
<<
"Load epoch_val:"
<<
*
epoch_val
;
MS_LOG
(
INFO
)
<<
"Load epoch_val:"
<<
*
epoch_val
;
...
@@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
...
@@ -927,7 +928,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
))
{
return
tensor
->
device_address
().
get
()
==
nullptr
||
tensor
->
device_address
()
!=
device_address
;
return
tensor
->
device_address
().
get
()
==
nullptr
||
tensor
->
device_address
()
!=
device_address
;
}
}
if
(
tensor
->
is_dirty
())
{
if
(
tensor
->
NeedSyncHostToDevice
())
{
return
true
;
return
true
;
}
}
if
(
tensor
->
device_address
()
!=
device_address
)
{
if
(
tensor
->
device_address
()
!=
device_address
)
{
...
@@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
...
@@ -976,7 +977,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
MS_LOG
(
EXCEPTION
)
<<
"SyncHostToDevice failed."
;
}
}
}
}
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
}
}
}
}
...
@@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
...
@@ -1124,7 +1125,7 @@ void SessionBasic::Summary(KernelGraph *graph) {
tensor
->
data_type
(),
tensor
->
data_c
()))
{
tensor
->
data_type
(),
tensor
->
data_c
()))
{
MS_LOG
(
ERROR
)
<<
"Failed to sync output from device to host."
;
MS_LOG
(
ERROR
)
<<
"Failed to sync output from device to host."
;
}
}
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
params_list
[
output_item
.
first
]
=
tensor
;
params_list
[
output_item
.
first
]
=
tensor
;
}
}
// call callback function here
// call callback function here
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
f42b3bbf
...
@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
...
@@ -373,7 +373,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input
);
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input
);
auto
new_tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
tensor
->
data_type
(),
tensor
->
shape
(),
tensor
->
data_ptr
());
auto
new_tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
tensor
->
data_type
(),
tensor
->
shape
(),
tensor
->
data_ptr
());
new_tensor
->
set_device_address
(
tensor
->
device_address
());
new_tensor
->
set_device_address
(
tensor
->
device_address
());
new_tensor
->
set_
dirty
(
tensor
->
is_dirty
());
new_tensor
->
set_
sync_status
(
tensor
->
sync_status
());
result
[
i
]
=
new_tensor
;
result
[
i
]
=
new_tensor
;
}
}
*
status
=
PYNATIVE_SUCCESS
;
*
status
=
PYNATIVE_SUCCESS
;
...
...
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
浏览文件 @
f42b3bbf
...
@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
...
@@ -162,7 +162,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
}
}
if
(
bound_addresses_
.
find
(
address
)
!=
bound_addresses_
.
end
())
{
if
(
bound_addresses_
.
find
(
address
)
!=
bound_addresses_
.
end
())
{
tensor
->
set_device_address
(
address
);
tensor
->
set_device_address
(
address
);
tensor
->
set_
need_sync
(
true
);
tensor
->
set_
sync_status
(
kNeedSyncDeviceToHostImmediately
);
}
else
{
}
else
{
if
(
infer_type_id
!=
device_type_id
)
{
if
(
infer_type_id
!=
device_type_id
)
{
size_t
type_size
=
GetTypeByte
(
TypeIdToType
(
device_type_id
));
size_t
type_size
=
GetTypeByte
(
TypeIdToType
(
device_type_id
));
...
@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
...
@@ -170,15 +170,16 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
size_t
tensor_size
=
std
::
accumulate
(
data_shape
.
begin
(),
data_shape
.
end
(),
type_size
,
std
::
multiplies
<
size_t
>
());
size_t
tensor_size
=
std
::
accumulate
(
data_shape
.
begin
(),
data_shape
.
end
(),
type_size
,
std
::
multiplies
<
size_t
>
());
address
->
ptr_
=
resource_manager_
.
MemMalloc
(
tensor_size
);
address
->
ptr_
=
resource_manager_
.
MemMalloc
(
tensor_size
);
tensor
->
set_device_address
(
address
);
tensor
->
set_device_address
(
address
);
tensor
->
set_
need_sync
(
true
);
tensor
->
set_
sync_status
(
kNeedSyncDeviceToHostImmediately
);
}
else
{
}
else
{
tensor
->
set_device_address
(
nullptr
);
tensor
->
set_device_address
(
nullptr
);
address
->
ptr_
=
tensor
->
data_c
();
address
->
ptr_
=
tensor
->
data_c
();
tensor
->
set_sync_status
(
kNoNeedSync
);
}
}
address
->
ref_count_
=
INIT_NODE_REF
;
address
->
ref_count_
=
INIT_NODE_REF
;
(
void
)
bound_addresses_
.
insert
(
address
);
(
void
)
bound_addresses_
.
insert
(
address
);
}
}
tensor
->
set_dirty
(
false
);
return
tensor
;
return
tensor
;
}
}
...
@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
...
@@ -247,7 +248,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const
tensor
->
data_c
()))
{
tensor
->
data_c
()))
{
MS_LOG
(
EXCEPTION
)
<<
"Parameter node sync host to device failed!"
;
MS_LOG
(
EXCEPTION
)
<<
"Parameter node sync host to device failed!"
;
}
}
tensor
->
set_
dirty
(
tru
e
);
tensor
->
set_
sync_status
(
kNeedSyncHostToDevic
e
);
}
}
address
->
ref_count_
=
INIT_NODE_REF
;
address
->
ref_count_
=
INIT_NODE_REF
;
tensor
->
set_device_address
(
address
);
tensor
->
set_device_address
(
address
);
...
...
mindspore/ccsrc/runtime/device/kernel_adjust.cc
浏览文件 @
f42b3bbf
...
@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
...
@@ -534,7 +534,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
auto
pk_node
=
input_node
->
cast
<
ParameterPtr
>
();
auto
pk_node
=
input_node
->
cast
<
ParameterPtr
>
();
MS_EXCEPTION_IF_NULL
(
tensor
);
MS_EXCEPTION_IF_NULL
(
tensor
);
MS_EXCEPTION_IF_NULL
(
pk_node
);
MS_EXCEPTION_IF_NULL
(
pk_node
);
if
(
tensor
->
is_dirty
()
||
!
pk_node
->
has_default
())
{
if
(
tensor
->
NeedSyncHostToDevice
()
||
!
pk_node
->
has_default
())
{
need_sync
=
true
;
need_sync
=
true
;
}
}
}
}
...
@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
...
@@ -551,7 +551,7 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph
return
false
;
return
false
;
}
}
}
}
tensor
->
set_
dirty
(
false
);
tensor
->
set_
sync_status
(
kNoNeedSync
);
}
}
return
true
;
return
true
;
}
}
...
...
mindspore/core/ir/tensor.cc
浏览文件 @
f42b3bbf
...
@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor)
...
@@ -422,10 +422,9 @@ Tensor::Tensor(const Tensor &tensor)
:
MetaTensor
(
tensor
),
:
MetaTensor
(
tensor
),
init_flag_
(
tensor
.
init_flag_
),
init_flag_
(
tensor
.
init_flag_
),
data_
(
tensor
.
data_
),
data_
(
tensor
.
data_
),
dirty_
(
tensor
.
dirty_
),
id_
(
tensor
.
id_
),
id_
(
tensor
.
id_
),
event_
(
tensor
.
event_
),
event_
(
tensor
.
event_
),
need_sync_
(
tensor
.
need_sync
_
),
sync_status_
(
tensor
.
sync_status
_
),
device_sync_
(
tensor
.
device_sync_
),
device_sync_
(
tensor
.
device_sync_
),
padding_type_
(
tensor
.
padding_type
())
{}
padding_type_
(
tensor
.
padding_type
())
{}
...
@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
...
@@ -433,10 +432,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
:
MetaTensor
(
data_type
,
tensor
.
shape_
),
:
MetaTensor
(
data_type
,
tensor
.
shape_
),
init_flag_
(
tensor
.
init_flag_
),
init_flag_
(
tensor
.
init_flag_
),
data_
(
MakeTensorData
(
data_type
,
tensor
.
shape_
,
tensor
.
data_
->
data
(),
tensor
.
data_type_
)),
data_
(
MakeTensorData
(
data_type
,
tensor
.
shape_
,
tensor
.
data_
->
data
(),
tensor
.
data_type_
)),
dirty_
(
tensor
.
dirty_
),
id_
(
tensor
.
id_
),
id_
(
tensor
.
id_
),
event_
(
tensor
.
event_
),
event_
(
tensor
.
event_
),
need_sync_
(
tensor
.
need_sync
_
),
sync_status_
(
tensor
.
sync_status
_
),
device_sync_
(
tensor
.
device_sync_
),
device_sync_
(
tensor
.
device_sync_
),
padding_type_
(
tensor
.
padding_type
())
{}
padding_type_
(
tensor
.
padding_type
())
{}
...
@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
...
@@ -483,12 +481,11 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
Tensor
&
Tensor
::
AssignValue
(
const
Tensor
&
tensor
)
{
Tensor
&
Tensor
::
AssignValue
(
const
Tensor
&
tensor
)
{
if
(
this
!=
&
tensor
)
{
if
(
this
!=
&
tensor
)
{
MetaTensor
::
operator
=
(
tensor
);
MetaTensor
::
operator
=
(
tensor
);
dirty_
=
tensor
.
dirty_
;
device_sync_
=
tensor
.
device_sync_
;
device_sync_
=
tensor
.
device_sync_
;
data_
=
tensor
.
data_
;
data_
=
tensor
.
data_
;
id_
=
tensor
.
id_
;
id_
=
tensor
.
id_
;
event_
=
tensor
.
event_
;
event_
=
tensor
.
event_
;
need_sync_
=
tensor
.
need_sync
_
;
sync_status_
=
tensor
.
sync_status
_
;
padding_type_
=
tensor
.
padding_type_
;
padding_type_
=
tensor
.
padding_type_
;
}
}
return
*
this
;
return
*
this
;
...
...
mindspore/core/ir/tensor.h
浏览文件 @
f42b3bbf
...
@@ -36,7 +36,7 @@
...
@@ -36,7 +36,7 @@
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace
mindspore
{
namespace
mindspore
{
// brief mindspore::tensor namespace
// brief mindspore::tensor namespace
//
enum
TensorSyncStatus
{
kNoNeedSync
,
kNeedSyncHostToDevice
,
kNeedSyncDeviceToHost
,
kNeedSyncDeviceToHostImmediately
};
// A sub namespace in ME to support tensor related definition.
// A sub namespace in ME to support tensor related definition.
namespace
tensor
{
namespace
tensor
{
// Tensor data interface.
// Tensor data interface.
...
@@ -260,9 +260,6 @@ class Tensor : public MetaTensor {
...
@@ -260,9 +260,6 @@ class Tensor : public MetaTensor {
bool
is_init
()
const
{
return
init_flag_
;
}
bool
is_init
()
const
{
return
init_flag_
;
}
void
set_init_flag
(
bool
flag
)
{
init_flag_
=
flag
;
}
void
set_init_flag
(
bool
flag
)
{
init_flag_
=
flag
;
}
bool
is_dirty
()
const
{
return
dirty_
;
}
void
set_dirty
(
const
bool
dirty
)
{
dirty_
=
dirty
;
}
DeviceSyncPtr
device_address
()
const
{
return
device_sync_
;
}
DeviceSyncPtr
device_address
()
const
{
return
device_sync_
;
}
void
set_device_address
(
const
DeviceSyncPtr
&
device_sync
)
{
device_sync_
=
device_sync
;
}
void
set_device_address
(
const
DeviceSyncPtr
&
device_sync
)
{
device_sync_
=
device_sync
;
}
void
set_padding_type
(
std
::
vector
<
Axis
>
padding_type
)
{
padding_type_
=
padding_type
;
}
void
set_padding_type
(
std
::
vector
<
Axis
>
padding_type
)
{
padding_type_
=
padding_type
;
}
...
@@ -293,17 +290,22 @@ class Tensor : public MetaTensor {
...
@@ -293,17 +290,22 @@ class Tensor : public MetaTensor {
event_
==
nullptr
;
event_
==
nullptr
;
}
}
void
set_need_sync
(
bool
need_sync
)
{
need_sync_
=
need_sync
;
}
void
set_sync_status
(
TensorSyncStatus
sync_status
)
{
sync_status_
=
sync_status
;
}
TensorSyncStatus
sync_status
()
const
{
return
sync_status_
;
}
bool
NeedSyncDeviceToHostImmediately
()
const
{
return
sync_status_
==
kNeedSyncDeviceToHostImmediately
;
}
bool
NeedSyncDeviceToHost
()
const
{
return
sync_status_
==
kNeedSyncDeviceToHost
;
}
bool
need_sync
()
const
{
return
need_sync_
;
}
bool
NeedSyncHostToDevice
()
const
{
return
sync_status_
==
kNeedSyncHostToDevice
;
}
private:
private:
bool
init_flag_
{
false
};
bool
init_flag_
{
false
};
TensorDataPtr
data_
{
nullptr
};
TensorDataPtr
data_
{
nullptr
};
bool
dirty_
{
true
};
std
::
string
id_
{
""
};
std
::
string
id_
{
""
};
std
::
shared_ptr
<
WaitEvent
>
event_
{
nullptr
};
std
::
shared_ptr
<
WaitEvent
>
event_
{
nullptr
};
bool
need_sync_
{
fals
e
};
TensorSyncStatus
sync_status_
{
kNeedSyncHostToDevic
e
};
DeviceSyncPtr
device_sync_
{
nullptr
};
DeviceSyncPtr
device_sync_
{
nullptr
};
std
::
vector
<
Axis
>
padding_type_
;
std
::
vector
<
Axis
>
padding_type_
;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录