Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2d973c95
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2d973c95
编写于
7月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3161 Set output value for dynamic graph.
Merge pull request !3161 from flywind/output
上级
1e88d64b
570da089
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
291 addition
and
57 deletion
+291
-57
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+11
-4
mindspore/ccsrc/backend/session/ascend_session.h
mindspore/ccsrc/backend/session/ascend_session.h
+2
-1
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+4
-3
mindspore/ccsrc/backend/session/gpu_session.h
mindspore/ccsrc/backend/session/gpu_session.h
+2
-1
mindspore/ccsrc/backend/session/kernel_graph.cc
mindspore/ccsrc/backend/session/kernel_graph.cc
+36
-1
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
+32
-0
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
+1
-0
mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h
...re/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h
+5
-1
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
...ore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
+5
-1
mindspore/ccsrc/pipeline/pynative/base.h
mindspore/ccsrc/pipeline/pynative/base.h
+2
-0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+60
-9
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+7
-1
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+78
-28
mindspore/ccsrc/runtime/device/kernel_runtime.h
mindspore/ccsrc/runtime/device/kernel_runtime.h
+3
-1
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+21
-0
mindspore/ccsrc/utils/convert_utils.h
mindspore/ccsrc/utils/convert_utils.h
+3
-0
mindspore/core/ir/anf.h
mindspore/core/ir/anf.h
+14
-2
mindspore/core/ir/func_graph_cloner.cc
mindspore/core/ir/func_graph_cloner.cc
+3
-0
tests/ut/python/pynative_mode/test_high_order_grad.py
tests/ut/python/pynative_mode/test_high_order_grad.py
+1
-1
tests/vm_impl/vm_me.py
tests/vm_impl/vm_me.py
+1
-3
未找到文件。
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
2d973c95
...
@@ -608,14 +608,20 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
...
@@ -608,14 +608,20 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Run op "
<<
op_run_info
.
op_name
<<
" start!"
;
MS_LOG
(
INFO
)
<<
"Run op "
<<
op_run_info
.
op_name
<<
" start!"
;
// malloc mem
// malloc mem
RunOpMemoryAlloc
(
input_tensors
,
graph
.
get
());
RunOpMemoryAlloc
(
op_run_info
.
value
,
input_tensors
,
graph
.
get
());
// load input data to device
// load input data to device
LoadInputData
(
graph
,
input_tensors
);
LoadInputData
(
graph
,
input_tensors
);
// run op
// run op
RunOpExecTask
(
graph
);
RunOpExecTask
(
graph
);
// get output
// get output
VectorRef
outputs
;
VectorRef
outputs
;
UpdateOutputs
(
graph
,
&
outputs
,
input_tensors
);
if
(
op_run_info
.
value
!=
nullptr
)
{
std
::
vector
<
tensor
::
TensorPtr
>
pre_output_tensors
;
TensorValueToTensor
(
op_run_info
.
value
,
&
pre_output_tensors
);
std
::
copy
(
pre_output_tensors
.
begin
(),
pre_output_tensors
.
end
(),
std
::
back_inserter
(
outputs
));
}
else
{
UpdateOutputs
(
graph
,
&
outputs
,
input_tensors
);
}
// trans output to tuple
// trans output to tuple
auto
output_tensors
=
TransformBaseRefListToTuple
(
outputs
);
auto
output_tensors
=
TransformBaseRefListToTuple
(
outputs
);
if
(
!
utils
::
isa
<
PyObjectRef
>
(
output_tensors
)
||
if
(
!
utils
::
isa
<
PyObjectRef
>
(
output_tensors
)
||
...
@@ -744,14 +750,15 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
...
@@ -744,14 +750,15 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
MS_LOG
(
INFO
)
<<
"Finish!"
;
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
}
void
AscendSession
::
RunOpMemoryAlloc
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
void
AscendSession
::
RunOpMemoryAlloc
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
{
KernelGraph
*
kernel_graph
)
const
{
MS_LOG
(
INFO
)
<<
"Start memory alloc!"
;
MS_LOG
(
INFO
)
<<
"Start memory alloc!"
;
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
opt
::
RemoveNopNode
(
kernel_graph
);
opt
::
RemoveNopNode
(
kernel_graph
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id_
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
runtime_instance
->
RunOpAssignMemory
(
input_tensors
,
kernel_graph
);
runtime_instance
->
RunOpAssignMemory
(
pre_output_value
,
input_tensors
,
kernel_graph
);
MS_LOG
(
INFO
)
<<
"Finish!"
;
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
}
...
...
mindspore/ccsrc/backend/session/ascend_session.h
浏览文件 @
2d973c95
...
@@ -79,7 +79,8 @@ class AscendSession : public SessionBasic {
...
@@ -79,7 +79,8 @@ class AscendSession : public SessionBasic {
void
AssignStream
(
NotNull
<
KernelGraphPtr
>
kernel_graph
)
const
;
void
AssignStream
(
NotNull
<
KernelGraphPtr
>
kernel_graph
)
const
;
void
BuildKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
BuildKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
MemoryAlloc
(
KernelGraph
*
kernel_graph
)
const
;
void
MemoryAlloc
(
KernelGraph
*
kernel_graph
)
const
;
void
RunOpMemoryAlloc
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
;
void
RunOpMemoryAlloc
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
;
void
RunOpMemoryClear
(
const
KernelGraph
*
kernel_graph
)
const
;
void
RunOpMemoryClear
(
const
KernelGraph
*
kernel_graph
)
const
;
void
GenerateTaskInfo
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
GenerateTaskInfo
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
LoadTask
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
void
LoadTask
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
2d973c95
...
@@ -102,12 +102,13 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
...
@@ -102,12 +102,13 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
runtime_instance
->
AssignMemory
(
kernel_graph
);
runtime_instance
->
AssignMemory
(
kernel_graph
);
}
}
void
GPUSession
::
RunOpAllocateMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
void
GPUSession
::
RunOpAllocateMemory
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
{
KernelGraph
*
kernel_graph
)
const
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetSingleKernelRuntime
(
kGPUDevice
,
device_id_
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetSingleKernelRuntime
(
kGPUDevice
,
device_id_
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
runtime_instance
->
RunOpAssignMemory
(
input_tensors
,
kernel_graph
);
runtime_instance
->
RunOpAssignMemory
(
pre_output_value
,
input_tensors
,
kernel_graph
);
}
}
void
GPUSession
::
RunOpClearMemory
(
KernelGraph
*
kernel_graph
)
const
{
void
GPUSession
::
RunOpClearMemory
(
KernelGraph
*
kernel_graph
)
const
{
...
@@ -292,7 +293,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
...
@@ -292,7 +293,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
// Remove NoOp from execution graph
// Remove NoOp from execution graph
opt
::
RemoveNopNode
(
kernel_graph
.
get
());
opt
::
RemoveNopNode
(
kernel_graph
.
get
());
RunOpAllocateMemory
(
input_tensors
,
kernel_graph
.
get
());
RunOpAllocateMemory
(
op_run_info
.
value
,
input_tensors
,
kernel_graph
.
get
());
// Execute the computation
// Execute the computation
LoadInputData
(
kernel_graph
,
input_tensors
);
LoadInputData
(
kernel_graph
,
input_tensors
);
Execute
(
kernel_graph
);
Execute
(
kernel_graph
);
...
...
mindspore/ccsrc/backend/session/gpu_session.h
浏览文件 @
2d973c95
...
@@ -59,7 +59,8 @@ class GPUSession : public SessionBasic {
...
@@ -59,7 +59,8 @@ class GPUSession : public SessionBasic {
void
AllocateMemory
(
KernelGraph
*
kernel_graph
)
const
;
void
AllocateMemory
(
KernelGraph
*
kernel_graph
)
const
;
void
RunOpAllocateMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
;
void
RunOpAllocateMemory
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
KernelGraph
*
kernel_graph
)
const
;
void
RunOpClearMemory
(
KernelGraph
*
kernel_graph
)
const
;
void
RunOpClearMemory
(
KernelGraph
*
kernel_graph
)
const
;
...
...
mindspore/ccsrc/backend/session/kernel_graph.cc
浏览文件 @
2d973c95
...
@@ -95,6 +95,38 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
...
@@ -95,6 +95,38 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
}
}
return
false
;
return
false
;
}
}
void
SyncDeviceInfoToValueNode
(
const
ValueNodePtr
&
value_node
,
std
::
vector
<
std
::
string
>
*
device_formats
,
std
::
vector
<
TypeId
>
*
device_types
)
{
MS_EXCEPTION_IF_NULL
(
value_node
);
MS_EXCEPTION_IF_NULL
(
device_formats
);
MS_EXCEPTION_IF_NULL
(
device_types
);
ValuePtr
value
=
value_node
->
value
();
std
::
vector
<
tensor
::
TensorPtr
>
tensors
;
TensorValueToTensor
(
value
,
&
tensors
);
if
(
!
tensors
.
empty
())
{
if
(
tensors
.
size
()
!=
AnfAlgo
::
GetOutputTensorNum
(
value_node
))
{
MS_LOG
(
EXCEPTION
)
<<
"The size of tensors converted from value ["
<<
tensors
.
size
()
<<
"] is not equal to output size of value node ["
<<
AnfAlgo
::
GetOutputTensorNum
(
value_node
)
<<
"]"
;
}
device_formats
->
clear
();
device_types
->
clear
();
for
(
const
auto
&
tensor
:
tensors
)
{
MS_EXCEPTION_IF_NULL
(
tensor
);
auto
device_sync
=
tensor
->
device_address
();
if
(
device_sync
!=
nullptr
)
{
auto
device_address
=
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
device_sync
);
MS_EXCEPTION_IF_NULL
(
device_address
);
device_formats
->
emplace_back
(
device_address
->
format
());
device_types
->
emplace_back
(
device_address
->
type_id
());
continue
;
}
device_formats
->
emplace_back
(
kOpFormat_DEFAULT
);
device_types
->
emplace_back
(
kTypeUnknown
);
}
}
}
}
// namespace
}
// namespace
AnfNodePtr
KernelGraph
::
MakeValueNode
(
const
AnfNodePtr
&
node
)
{
AnfNodePtr
KernelGraph
::
MakeValueNode
(
const
AnfNodePtr
&
node
)
{
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
...
@@ -347,10 +379,12 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
...
@@ -347,10 +379,12 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// set the format of value_node to DEFAULT_FORMAT
// set the format of value_node to DEFAULT_FORMAT
std
::
vector
<
TypeId
>
types
;
std
::
vector
<
TypeId
>
types
;
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
})
;
std
::
vector
<
std
::
string
>
formats
=
{
kOpFormat_DEFAULT
}
;
if
(
node
->
isa
<
ValueNode
>
())
{
if
(
node
->
isa
<
ValueNode
>
())
{
kernel_info
->
SetFeatureMapFlag
(
false
);
kernel_info
->
SetFeatureMapFlag
(
false
);
types
.
emplace_back
(
kTypeUnknown
);
types
.
emplace_back
(
kTypeUnknown
);
auto
value_node
=
node
->
cast
<
ValueNodePtr
>
();
SyncDeviceInfoToValueNode
(
value_node
,
&
formats
,
&
types
);
}
}
if
(
node
->
isa
<
Parameter
>
())
{
if
(
node
->
isa
<
Parameter
>
())
{
auto
parameter
=
node
->
cast
<
ParameterPtr
>
();
auto
parameter
=
node
->
cast
<
ParameterPtr
>
();
...
@@ -360,6 +394,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
...
@@ -360,6 +394,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
types
.
push_back
(
is_weight
?
kTypeUnknown
:
AnfAlgo
::
GetOutputInferDataType
(
parameter
,
0
));
types
.
push_back
(
is_weight
?
kTypeUnknown
:
AnfAlgo
::
GetOutputInferDataType
(
parameter
,
0
));
}
}
// set parameter initaial device data type
// set parameter initaial device data type
kernel_build_info_builder
->
SetOutputsFormat
(
formats
);
kernel_build_info_builder
->
SetOutputsDeviceType
(
types
);
kernel_build_info_builder
->
SetOutputsDeviceType
(
types
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
node
.
get
());
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
node
.
get
());
}
}
...
...
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
浏览文件 @
2d973c95
...
@@ -216,6 +216,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
...
@@ -216,6 +216,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceGradFpropApp
>
(
cnode_morph
->
debug_info
()));
TraceManager
::
DebugTrace
(
std
::
make_shared
<
TraceGradFpropApp
>
(
cnode_morph
->
debug_info
()));
auto
k_app
=
k_graph_
->
NewCNode
(
inputs
);
auto
k_app
=
k_graph_
->
NewCNode
(
inputs
);
TraceManager
::
EndTrace
();
TraceManager
::
EndTrace
();
ReplaceEquivdout
(
k_app
,
cnode_morph
->
forward
());
for
(
size_t
i
=
0
;
i
<
param_adjoints
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_adjoints
.
size
();
++
i
)
{
param_adjoints
[
i
]
->
RegisterKUser
(
k_app
,
i
);
param_adjoints
[
i
]
->
RegisterKUser
(
k_app
,
i
);
}
}
...
@@ -237,6 +238,37 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
...
@@ -237,6 +238,37 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
return
node_adjoint
;
return
node_adjoint
;
}
}
void
DFunctor
::
ReplaceEquivdout
(
const
CNodePtr
&
cnode
,
const
ValuePtr
&
forward
)
{
if
(
forward
==
nullptr
)
{
return
;
}
auto
&
input
=
cnode
->
input
(
0
);
if
(
!
IsValueNode
<
FuncGraph
>
(
input
))
{
return
;
}
auto
fg
=
GetValueNode
<
FuncGraphPtr
>
(
input
);
auto
output
=
fg
->
output
();
if
(
!
output
->
isa
<
CNode
>
())
{
return
;
}
auto
cnode_output
=
output
->
cast
<
CNodePtr
>
();
auto
&
cnode_input
=
cnode_output
->
input
(
1
);
if
(
!
cnode_input
->
isa
<
CNode
>
())
{
return
;
}
auto
&
input_fg
=
cnode_output
->
input
(
2
);
if
(
!
IsValueNode
<
FuncGraph
>
(
input_fg
))
{
return
;
}
auto
equivdout
=
cnode_input
->
cast
<
CNodePtr
>
();
auto
func_graph
=
GetValueNode
<
FuncGraphPtr
>
(
input_fg
);
auto
manager
=
Manage
({
fg
,
func_graph
},
false
);
MS_LOG
(
DEBUG
)
<<
"Replace: "
<<
equivdout
->
ToString
()
<<
" with "
<<
forward
;
auto
value_node
=
NewValueNode
(
forward
);
value_node
->
set_has_new_value
(
true
);
manager
->
Replace
(
equivdout
,
value_node
);
}
bool
DFunctor
::
IsFreeMorphism
(
const
AnfNodePtr
&
node
)
{
bool
DFunctor
::
IsFreeMorphism
(
const
AnfNodePtr
&
node
)
{
// Do not care about non-CNode
// Do not care about non-CNode
if
(
!
node
->
isa
<
CNode
>
())
{
if
(
!
node
->
isa
<
CNode
>
())
{
...
...
mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
浏览文件 @
2d973c95
...
@@ -95,6 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
...
@@ -95,6 +95,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
// Update k hole with adjoint_definition, only applied in recursive case.
// Update k hole with adjoint_definition, only applied in recursive case.
void
UpdateAdjoint
(
const
AdjointPtr
&
adjoint_definition
);
void
UpdateAdjoint
(
const
AdjointPtr
&
adjoint_definition
);
void
CallDoutHoleOnTape
();
void
CallDoutHoleOnTape
();
void
ReplaceEquivdout
(
const
CNodePtr
&
cnode
,
const
ValuePtr
&
forward
);
std
::
unordered_map
<
AnfNodePtr
,
AdjointPtr
>
anfnode_to_adjoin_
;
std
::
unordered_map
<
AnfNodePtr
,
AdjointPtr
>
anfnode_to_adjoin_
;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
...
...
mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h
浏览文件 @
2d973c95
...
@@ -88,7 +88,9 @@ class GetitemConstEliminater : public AnfVisitor {
...
@@ -88,7 +88,9 @@ class GetitemConstEliminater : public AnfVisitor {
AnfVisitor
::
Match
(
prim
::
kPrimListGetItem
,
{
IsVNode
,
IsVNode
})(
node
);
AnfVisitor
::
Match
(
prim
::
kPrimListGetItem
,
{
IsVNode
,
IsVNode
})(
node
);
if
(
is_match_
)
{
if
(
is_match_
)
{
return
NewValueNode
((
*
tuple_
)[
id_
]);
auto
out
=
NewValueNode
((
*
tuple_
)[
id_
]);
out
->
set_has_new_value
(
has_new_value_
);
return
out
;
}
}
return
nullptr
;
return
nullptr
;
}
}
...
@@ -96,6 +98,7 @@ class GetitemConstEliminater : public AnfVisitor {
...
@@ -96,6 +98,7 @@ class GetitemConstEliminater : public AnfVisitor {
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
void
Visit
(
const
ValueNodePtr
&
vnode
)
override
{
if
(
IsValueNode
<
ValueTuple
>
(
vnode
))
{
if
(
IsValueNode
<
ValueTuple
>
(
vnode
))
{
tuple_
=
GetValueNode
<
ValueTuplePtr
>
(
vnode
);
tuple_
=
GetValueNode
<
ValueTuplePtr
>
(
vnode
);
has_new_value_
=
vnode
->
has_new_value
();
}
}
if
(
tuple_
!=
nullptr
&&
IsValueNode
<
Int32Imm
>
(
vnode
))
{
if
(
tuple_
!=
nullptr
&&
IsValueNode
<
Int32Imm
>
(
vnode
))
{
id_
=
IntToSize
(
GetValue
<
int
>
(
vnode
->
value
()));
id_
=
IntToSize
(
GetValue
<
int
>
(
vnode
->
value
()));
...
@@ -115,6 +118,7 @@ class GetitemConstEliminater : public AnfVisitor {
...
@@ -115,6 +118,7 @@ class GetitemConstEliminater : public AnfVisitor {
bool
is_match_
{
false
};
bool
is_match_
{
false
};
size_t
id_
{
0
};
size_t
id_
{
0
};
ValueTuplePtr
tuple_
{
nullptr
};
ValueTuplePtr
tuple_
{
nullptr
};
bool
has_new_value_
{
false
};
};
};
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
浏览文件 @
2d973c95
...
@@ -205,7 +205,11 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
...
@@ -205,7 +205,11 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
AbstractBasePtr
AnalysisEngine
::
EvalValueNode
(
const
ValueNodePtr
&
value_node
,
const
AnfNodeConfigPtr
&
conf
)
{
AbstractBasePtr
AnalysisEngine
::
EvalValueNode
(
const
ValueNodePtr
&
value_node
,
const
AnfNodeConfigPtr
&
conf
)
{
MS_EXCEPTION_IF_NULL
(
conf
);
MS_EXCEPTION_IF_NULL
(
conf
);
MS_EXCEPTION_IF_NULL
(
value_node
);
MS_EXCEPTION_IF_NULL
(
value_node
);
return
ToAbstract
(
value_node
->
value
(),
conf
->
context
(),
conf
);
auto
out
=
ToAbstract
(
value_node
->
value
(),
conf
->
context
(),
conf
);
if
(
value_node
->
has_new_value
())
{
out
=
out
->
Broaden
();
}
return
out
;
}
}
EvalResultPtr
AnalysisEngine
::
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
)
{
EvalResultPtr
AnalysisEngine
::
EvalCNode
(
const
CNodePtr
&
cnode
,
const
AnfNodeConfigPtr
&
conf
)
{
...
...
mindspore/ccsrc/pipeline/pynative/base.h
浏览文件 @
2d973c95
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <unordered_set>
#include <unordered_set>
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
#include "ir/anf.h"
#include "ir/primitive_py.h"
#include "ir/primitive_py.h"
#include "abstract/abstract_value.h"
#include "abstract/abstract_value.h"
...
@@ -51,6 +52,7 @@ struct OpExecInfo {
...
@@ -51,6 +52,7 @@ struct OpExecInfo {
PrimitivePyPtr
py_primitive
;
PrimitivePyPtr
py_primitive
;
std
::
string
op_name
;
std
::
string
op_name
;
AbstractBasePtr
abstract
;
AbstractBasePtr
abstract
;
ValuePtr
value
=
nullptr
;
py
::
tuple
op_inputs
;
py
::
tuple
op_inputs
;
py
::
tuple
inputs_mask
;
py
::
tuple
inputs_mask
;
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
2d973c95
...
@@ -111,7 +111,7 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
...
@@ -111,7 +111,7 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
return
converted_ret
;
return
converted_ret
;
}
}
std
::
string
GetId
(
const
py
::
object
&
obj
)
{
st
atic
st
d
::
string
GetId
(
const
py
::
object
&
obj
)
{
py
::
object
to_process
=
obj
;
py
::
object
to_process
=
obj
;
std
::
string
prefix
=
""
;
std
::
string
prefix
=
""
;
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
))
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
))
{
...
@@ -141,6 +141,11 @@ std::string GetId(const py::object &obj) {
...
@@ -141,6 +141,11 @@ std::string GetId(const py::object &obj) {
return
py
::
cast
<
std
::
string
>
(
ret
);
return
py
::
cast
<
std
::
string
>
(
ret
);
}
}
static
std
::
string
GetOpId
(
const
OpExecInfoPtr
&
op_exec_info
)
{
auto
id
=
GetId
(
op_exec_info
->
py_primitive
->
GetPyObj
());
return
id
;
}
py
::
object
GetTupleObj
(
const
py
::
object
&
obj
)
{
py
::
object
GetTupleObj
(
const
py
::
object
&
obj
)
{
py
::
module
mod
=
parse
::
python_adapter
::
GetPyModule
(
parse
::
PYTHON_MOD_PARSE_MODULE
);
py
::
module
mod
=
parse
::
python_adapter
::
GetPyModule
(
parse
::
PYTHON_MOD_PARSE_MODULE
);
py
::
object
obj_tuple
=
parse
::
python_adapter
::
CallPyModFn
(
mod
,
parse
::
PYTHON_MOD_GET_DEFAULT_INPUT
,
obj
);
py
::
object
obj_tuple
=
parse
::
python_adapter
::
CallPyModFn
(
mod
,
parse
::
PYTHON_MOD_GET_DEFAULT_INPUT
,
obj
);
...
@@ -317,6 +322,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
...
@@ -317,6 +322,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
}
}
op_exec_info
->
py_primitive
=
prim
;
op_exec_info
->
py_primitive
=
prim
;
op_exec_info
->
op_attrs
=
py
::
getattr
(
args
[
PY_PRIM
],
"attrs"
);
op_exec_info
->
op_attrs
=
py
::
getattr
(
args
[
PY_PRIM
],
"attrs"
);
op_exec_info
->
value
=
PynativeExecutor
::
GetInstance
()
->
GetForwardValue
(
op_exec_info
);
if
(
op_exec_info
->
op_inputs
.
size
()
!=
op_exec_info
->
inputs_mask
.
size
())
{
if
(
op_exec_info
->
op_inputs
.
size
()
!=
op_exec_info
->
inputs_mask
.
size
())
{
MS_LOG
(
ERROR
)
<<
"Op:"
<<
op_exec_info
->
op_name
<<
" inputs size not equal op_mask"
;
MS_LOG
(
ERROR
)
<<
"Op:"
<<
op_exec_info
->
op_name
<<
" inputs size not equal op_mask"
;
return
nullptr
;
return
nullptr
;
...
@@ -606,7 +612,20 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
...
@@ -606,7 +612,20 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
return
result
;
return
result
;
}
}
AnfNodePtr
PynativeExecutor
::
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
)
{
ValuePtr
PynativeExecutor
::
GetForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
)
{
auto
id
=
GetOpId
(
op_exec_info
);
auto
op
=
id
;
op
.
append
(
std
::
to_string
(
op_id_map_
[
id
]));
auto
iter
=
op_forward_map_
.
find
(
op
);
if
(
iter
!=
op_forward_map_
.
end
())
{
++
op_id_map_
[
id
];
MS_LOG
(
DEBUG
)
<<
"Get: "
<<
op_exec_info
->
op_name
<<
"("
<<
op
<<
"), "
<<
iter
->
second
;
return
iter
->
second
;
}
return
nullptr
;
}
CNodePtr
PynativeExecutor
::
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
)
{
if
(
!
grad_flag_
||
graph_info_map_
.
empty
())
{
if
(
!
grad_flag_
||
graph_info_map_
.
empty
())
{
return
nullptr
;
return
nullptr
;
}
}
...
@@ -645,6 +664,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const
...
@@ -645,6 +664,34 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const
return
cnode
;
return
cnode
;
}
}
void
PynativeExecutor
::
SaveOpForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
,
const
ValuePtr
&
value
)
{
auto
id
=
GetOpId
(
op_exec_info
);
auto
op
=
id
;
op
.
append
(
std
::
to_string
(
op_id_map_
[
id
]));
auto
iter
=
op_forward_map_
.
find
(
op
);
if
(
iter
!=
op_forward_map_
.
end
())
{
return
;
}
op_forward_map_
[
op
]
=
value
;
++
op_id_map_
[
id
];
MS_LOG
(
DEBUG
)
<<
"Save: "
<<
op_exec_info
->
op_name
<<
"("
<<
op
<<
"), "
<<
value
;
}
void
PynativeExecutor
::
SaveAllResult
(
const
OpExecInfoPtr
&
op_exec_info
,
const
CNodePtr
&
cnode
,
const
py
::
tuple
&
out
)
{
if
(
!
grad_flag_
||
op_exec_info
->
value
!=
nullptr
)
{
return
;
}
py
::
object
out_real
=
out
;
if
(
out
.
size
()
==
1
)
{
out_real
=
out
[
0
];
}
auto
value
=
PyAttrValue
(
out_real
);
if
(
cnode
!=
nullptr
)
{
cnode
->
set_forward
(
value
);
}
SaveOpForwardValue
(
op_exec_info
,
value
);
}
AnfNodePtr
PynativeExecutor
::
GetObjNode
(
const
py
::
object
&
obj
)
{
AnfNodePtr
PynativeExecutor
::
GetObjNode
(
const
py
::
object
&
obj
)
{
auto
&
out
=
graph_info_map_
[
curr_g_
].
obj_node_map
[
GetId
(
obj
)];
auto
&
out
=
graph_info_map_
[
curr_g_
].
obj_node_map
[
GetId
(
obj
)];
if
(
out
.
second
.
size
()
==
1
&&
out
.
second
[
0
]
==
-
1
)
{
if
(
out
.
second
.
size
()
==
1
&&
out
.
second
[
0
]
==
-
1
)
{
...
@@ -657,6 +704,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
...
@@ -657,6 +704,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
node
=
curr_g_
->
NewCNode
(
tuple_get_item_inputs
);
node
=
curr_g_
->
NewCNode
(
tuple_get_item_inputs
);
}
}
MS_LOG
(
DEBUG
)
<<
"GetObjNode output"
<<
node
->
DebugString
(
6
);
MS_LOG
(
DEBUG
)
<<
"GetObjNode output"
<<
node
->
DebugString
(
6
);
node
->
cast
<
CNodePtr
>
()
->
set_forward
(
PyAttrValue
(
obj
));
return
node
;
return
node
;
}
}
...
@@ -690,11 +738,12 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
...
@@ -690,11 +738,12 @@ py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) {
return
err_ret
;
return
err_ret
;
}
}
auto
node
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
args
,
result
);
auto
c
node
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
args
,
result
);
if
(
node
!=
nullptr
)
{
if
(
c
node
!=
nullptr
)
{
node
->
set_abstract
(
op_exec_info
->
abstract
);
c
node
->
set_abstract
(
op_exec_info
->
abstract
);
MS_LOG
(
DEBUG
)
<<
"RunOp MakeCnode,new node is: "
<<
node
->
DebugString
();
MS_LOG
(
DEBUG
)
<<
"RunOp MakeCnode,new node is: "
<<
c
node
->
DebugString
();
}
}
PynativeExecutor
::
GetInstance
()
->
SaveAllResult
(
op_exec_info
,
cnode
,
result
);
MS_LOG
(
DEBUG
)
<<
"RunOp end"
;
MS_LOG
(
DEBUG
)
<<
"RunOp end"
;
return
result
;
return
result
;
}
}
...
@@ -1072,7 +1121,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
...
@@ -1072,7 +1121,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
void
PynativeExecutor
::
Clear
(
const
std
::
string
&
flag
)
{
void
PynativeExecutor
::
Clear
(
const
std
::
string
&
flag
)
{
if
(
!
flag
.
empty
())
{
if
(
!
flag
.
empty
())
{
MS_LOG
(
INFO
)
<<
"Clear res"
;
MS_LOG
(
DEBUG
)
<<
"Clear res"
;
(
void
)
graph_map_
.
erase
(
flag
);
(
void
)
graph_map_
.
erase
(
flag
);
(
void
)
cell_graph_map_
.
erase
(
flag
);
(
void
)
cell_graph_map_
.
erase
(
flag
);
Clean
();
Clean
();
...
@@ -1084,17 +1133,19 @@ void PynativeExecutor::Clear(const std::string &flag) {
...
@@ -1084,17 +1133,19 @@ void PynativeExecutor::Clear(const std::string &flag) {
return
;
return
;
}
}
MS_LOG
(
INFO
)
<<
"Clear"
;
MS_LOG
(
DEBUG
)
<<
"Clear"
;
top_g_
=
nullptr
;
top_g_
=
nullptr
;
curr_g_
=
nullptr
;
curr_g_
=
nullptr
;
graph_info_map_
.
clear
();
graph_info_map_
.
clear
();
op_id_map_
.
clear
();
std
::
stack
<
FuncGraphPtr
>
().
swap
(
graph_p_
);
std
::
stack
<
FuncGraphPtr
>
().
swap
(
graph_p_
);
}
}
void
PynativeExecutor
::
Clean
()
{
void
PynativeExecutor
::
Clean
()
{
MS_LOG
(
INFO
)
<<
"Clean all res"
;
MS_LOG
(
DEBUG
)
<<
"Clean all res"
;
Clear
();
Clear
();
grad_flag_
=
false
;
grad_flag_
=
false
;
op_forward_map_
.
clear
();
df_builder_
=
nullptr
;
df_builder_
=
nullptr
;
ad
::
CleanRes
();
ad
::
CleanRes
();
pipeline
::
ReclaimOptimizer
();
pipeline
::
ReclaimOptimizer
();
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
浏览文件 @
2d973c95
...
@@ -95,7 +95,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
...
@@ -95,7 +95,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
,
std
::
vector
<
int
>
index
)
{
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
,
std
::
vector
<
int
>
index
)
{
graph_info_map_
[
g
].
obj_node_map
[
obj
]
=
std
::
make_pair
(
node
,
index
);
graph_info_map_
[
g
].
obj_node_map
[
obj
]
=
std
::
make_pair
(
node
,
index
);
}
}
AnfNodePtr
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
);
CNodePtr
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
);
ValuePtr
GetForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
);
void
SaveOpForwardValue
(
const
OpExecInfoPtr
&
op_exec_info
,
const
ValuePtr
&
value
);
void
SaveForwardResult
(
const
CNodePtr
&
cnode
,
const
py
::
object
&
out
);
void
SaveAllResult
(
const
OpExecInfoPtr
&
op_exec_info
,
const
CNodePtr
&
cnode
,
const
py
::
tuple
&
out
);
py
::
object
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
);
py
::
object
Run
(
const
py
::
tuple
&
args
,
const
py
::
object
&
phase
);
void
Pushp
();
void
Pushp
();
...
@@ -116,6 +120,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
...
@@ -116,6 +120,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std
::
unordered_map
<
std
::
string
,
FuncGraphPtr
>
graph_map_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphPtr
>
graph_map_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphPtr
>
cell_graph_map_
;
std
::
unordered_map
<
std
::
string
,
FuncGraphPtr
>
cell_graph_map_
;
std
::
unordered_map
<
FuncGraphPtr
,
GraphInfo
>
graph_info_map_
;
std
::
unordered_map
<
FuncGraphPtr
,
GraphInfo
>
graph_info_map_
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
op_forward_map_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
op_id_map_
;
std
::
stack
<
FuncGraphPtr
>
graph_p_
;
std
::
stack
<
FuncGraphPtr
>
graph_p_
;
FuncGraphPtr
top_g_
;
FuncGraphPtr
top_g_
;
FuncGraphPtr
df_builder_
;
FuncGraphPtr
df_builder_
;
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
2d973c95
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/optimizer/common/helper.h"
#include "ir/value.h"
#include "ir/value.h"
using
mindspore
::
kernel
::
Address
;
using
mindspore
::
kernel
::
Address
;
using
mindspore
::
kernel
::
AddressPtr
;
using
mindspore
::
kernel
::
AddressPtr
;
...
@@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
...
@@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
UpdateRefNodeOutputMem
(
graph
);
UpdateRefNodeOutputMem
(
graph
);
}
}
void
KernelRuntime
::
RunOpAssignMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
void
KernelRuntime
::
RunOpAssignMemory
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
session
::
KernelGraph
*
graph
)
{
session
::
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
RunOpAssignInputMemory
(
input_tensors
,
graph
);
RunOpAssignInputMemory
(
input_tensors
,
graph
);
AssignStaticMemoryValueNode
(
graph
);
AssignStaticMemoryValueNode
(
graph
);
RunOpAssignOutputNodeMemory
(
pre_output_value
,
graph
);
for
(
const
auto
&
cnode
:
graph
->
execution_order
())
{
for
(
const
auto
&
cnode
:
graph
->
execution_order
())
{
RunOpAssignOutputMemory
(
cnode
);
RunOpAssignOutputMemory
(
cnode
);
RunOpAssignWorkSpaceMemory
(
cnode
);
RunOpAssignWorkSpaceMemory
(
cnode
);
...
@@ -322,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
...
@@ -322,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
}
}
}
}
void
KernelRuntime
::
RunOpAssignOutputNodeMemory
(
const
ValuePtr
&
pre_output_value
,
session
::
KernelGraph
*
graph
)
{
if
(
pre_output_value
==
nullptr
)
{
return
;
}
std
::
vector
<
tensor
::
TensorPtr
>
pre_output_tensors
;
TensorValueToTensor
(
pre_output_value
,
&
pre_output_tensors
);
MS_EXCEPTION_IF_NULL
(
graph
);
auto
output_nodes
=
graph
->
outputs
();
if
(
pre_output_tensors
.
size
()
!=
output_nodes
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The size of pre output tensors ["
<<
pre_output_tensors
.
size
()
<<
"] is not equal to the size of output nodes of graph ["
<<
output_nodes
.
size
()
<<
"]"
;
}
// share output address with pre output tensors
for
(
size_t
i
=
0
;
i
<
output_nodes
.
size
();
++
i
)
{
auto
output_node_with_index
=
AnfAlgo
::
VisitKernel
(
output_nodes
[
i
],
0
);
if
(
!
output_node_with_index
.
first
->
isa
<
CNode
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The output node should be a cnode , but it is "
<<
output_node_with_index
.
first
->
DebugString
();
}
auto
real_output_cnode
=
output_node_with_index
.
first
->
cast
<
CNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
real_output_cnode
);
MS_EXCEPTION_IF_NULL
(
pre_output_tensors
[
i
]);
if
(
pre_output_tensors
[
i
]
->
device_address
()
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"The address of pre output tensor ["
<<
i
<<
"] is a nullptr!"
;
}
if
(
opt
::
IsNopNode
(
real_output_cnode
))
{
if
(
real_output_cnode
->
inputs
().
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"The input size of output node: "
<<
real_output_cnode
->
DebugString
()
<<
" should large than one!"
;
}
AnfAlgo
::
SetOutputAddr
(
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
pre_output_tensors
[
i
]
->
device_address
()),
output_node_with_index
.
second
,
real_output_cnode
->
input
(
1
).
get
());
}
else
{
AnfAlgo
::
SetOutputAddr
(
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
pre_output_tensors
[
i
]
->
device_address
()),
output_node_with_index
.
second
,
output_node_with_index
.
first
.
get
());
}
}
}
void
KernelRuntime
::
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
)
{
void
KernelRuntime
::
AssignStaticMemoryInput
(
const
session
::
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
graph
);
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
...
@@ -573,32 +615,40 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
...
@@ -573,32 +615,40 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
auto
ms_context
=
MsContext
::
GetInstance
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
MS_EXCEPTION_IF_NULL
(
ms_context
);
auto
tensor
=
node_value
->
cast
<
TensorPtr
>
();
std
::
vector
<
tensor
::
TensorPtr
>
tensors
;
if
(
tensor
==
nullptr
)
{
TensorValueToTensor
(
node_value
,
&
tensors
);
MS_LOG
(
WARNING
)
<<
"Tensor is null"
;
for
(
const
auto
&
tensor
:
tensors
)
{
return
;
if
(
tensor
==
nullptr
)
{
}
MS_LOG
(
WARNING
)
<<
"Tensor is null"
;
size_t
tensor_size
=
tensor
->
data
().
nbytes
();
return
;
auto
node_size
=
CountNodeDeviceMemorySize
(
value_node
,
output_idx
);
}
TypeId
output_type_id
=
AnfAlgo
::
GetOutputDeviceDataType
(
value_node
,
output_idx
);
if
(
tensor
->
device_address
()
!=
nullptr
)
{
if
(
output_type_id
==
kTypeUnknown
)
{
AnfAlgo
::
SetOutputAddr
(
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
tensor
->
device_address
()),
output_idx
++
,
output_type_id
=
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
output_idx
);
value_node
.
get
());
}
continue
;
auto
output_format
=
AnfAlgo
::
GetOutputFormat
(
value_node
,
output_idx
);
}
DeviceAddressPtr
address
=
nullptr
;
size_t
tensor_size
=
tensor
->
data
().
nbytes
();
address
=
CreateDeviceAddress
(
nullptr
,
node_size
,
output_format
,
output_type_id
);
auto
node_size
=
CountNodeDeviceMemorySize
(
value_node
,
output_idx
);
MS_EXCEPTION_IF_NULL
(
address
);
TypeId
output_type_id
=
AnfAlgo
::
GetOutputDeviceDataType
(
value_node
,
output_idx
);
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
node_size
))
{
if
(
output_type_id
==
kTypeUnknown
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
node_size
;
output_type_id
=
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
output_idx
);
}
else
if
(
mem_manager_
->
MallocMem
(
address
,
kStaticMem
,
node_size
)
==
nullptr
)
{
}
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
node_size
;
auto
output_format
=
AnfAlgo
::
GetOutputFormat
(
value_node
,
output_idx
);
}
DeviceAddressPtr
address
=
nullptr
;
AnfAlgo
::
SetOutputAddr
(
address
,
output_idx
,
value_node
.
get
());
address
=
CreateDeviceAddress
(
nullptr
,
node_size
,
output_format
,
output_type_id
);
if
(
!
address
->
SyncHostToDevice
(
trans
::
GetRuntimePaddingShape
(
value_node
,
0
),
tensor_size
,
tensor
->
data_type
(),
MS_EXCEPTION_IF_NULL
(
address
);
tensor
->
data_c
()))
{
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
node_size
))
{
MS_EXCEPTION
(
NotExistsError
)
<<
"ValueNode SyncHostToDevice fail!"
<<
value_node
->
DebugString
()
<<
"node format is"
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
node_size
;
<<
AnfAlgo
::
GetOutputFormat
(
value_node
,
output_idx
)
<<
"node dtype is "
}
else
if
(
mem_manager_
->
MallocMem
(
address
,
kStaticMem
,
node_size
)
==
nullptr
)
{
<<
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
output_idx
);
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
node_size
;
}
AnfAlgo
::
SetOutputAddr
(
address
,
output_idx
,
value_node
.
get
());
if
(
!
address
->
SyncHostToDevice
(
trans
::
GetRuntimePaddingShape
(
value_node
,
0
),
tensor_size
,
tensor
->
data_type
(),
tensor
->
data_c
()))
{
MS_EXCEPTION
(
NotExistsError
)
<<
"ValueNode SyncHostToDevice fail!"
<<
value_node
->
DebugString
()
<<
"node format is"
<<
AnfAlgo
::
GetOutputFormat
(
value_node
,
output_idx
)
<<
"node dtype is "
<<
AnfAlgo
::
GetOutputInferDataType
(
value_node
,
output_idx
);
}
}
}
}
}
...
@@ -615,7 +665,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
...
@@ -615,7 +665,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
}
}
auto
&
node_value
=
value_node
->
value
();
auto
&
node_value
=
value_node
->
value
();
MS_EXCEPTION_IF_NULL
(
node_value
);
MS_EXCEPTION_IF_NULL
(
node_value
);
if
(
node_value
->
isa
<
Tensor
>
())
{
if
(
node_value
->
isa
<
Tensor
>
()
||
node_value
->
isa
<
ValueTuple
>
()
)
{
AssignValueNodeTensor
(
value_node
,
node_value
,
0
);
AssignValueNodeTensor
(
value_node
,
node_value
,
0
);
}
else
if
(
node_value
->
isa
<
StringImm
>
())
{
}
else
if
(
node_value
->
isa
<
StringImm
>
())
{
auto
value
=
GetValue
<
std
::
string
>
(
node_value
);
auto
value
=
GetValue
<
std
::
string
>
(
node_value
);
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.h
浏览文件 @
2d973c95
...
@@ -53,7 +53,8 @@ class KernelRuntime {
...
@@ -53,7 +53,8 @@ class KernelRuntime {
virtual
~
KernelRuntime
();
virtual
~
KernelRuntime
();
virtual
bool
Init
()
=
0
;
virtual
bool
Init
()
=
0
;
virtual
void
AssignMemory
(
session
::
KernelGraph
*
graph
);
virtual
void
AssignMemory
(
session
::
KernelGraph
*
graph
);
void
RunOpAssignMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
session
::
KernelGraph
*
graph
);
void
RunOpAssignMemory
(
const
ValuePtr
&
pre_output_value
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
session
::
KernelGraph
*
graph
);
void
RunOpClearMemory
(
const
session
::
KernelGraph
*
graph
);
void
RunOpClearMemory
(
const
session
::
KernelGraph
*
graph
);
bool
DumpDataEnabled
();
bool
DumpDataEnabled
();
bool
DumpDataEnabledIteration
();
bool
DumpDataEnabledIteration
();
...
@@ -108,6 +109,7 @@ class KernelRuntime {
...
@@ -108,6 +109,7 @@ class KernelRuntime {
void
RunOpAssignInputMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
session
::
KernelGraph
*
graph
);
void
RunOpAssignInputMemory
(
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
session
::
KernelGraph
*
graph
);
void
RunOpAssignOutputMemory
(
const
AnfNodePtr
&
kernel
);
void
RunOpAssignOutputMemory
(
const
AnfNodePtr
&
kernel
);
void
RunOpAssignWorkSpaceMemory
(
const
AnfNodePtr
&
kernel
);
void
RunOpAssignWorkSpaceMemory
(
const
AnfNodePtr
&
kernel
);
void
RunOpAssignOutputNodeMemory
(
const
ValuePtr
&
pre_output_value
,
session
::
KernelGraph
*
graph
);
void
AssignValueNodeTensor
(
const
ValueNodePtr
&
value_node
,
const
ValuePtr
&
node_value
,
size_t
output_idx
);
void
AssignValueNodeTensor
(
const
ValueNodePtr
&
value_node
,
const
ValuePtr
&
node_value
,
size_t
output_idx
);
DeviceAddressPtr
PreAssignCNodeMemory
(
const
AnfNodePtr
&
anf_node
,
size_t
index
);
DeviceAddressPtr
PreAssignCNodeMemory
(
const
AnfNodePtr
&
anf_node
,
size_t
index
);
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
2d973c95
...
@@ -607,4 +607,25 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
...
@@ -607,4 +607,25 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) {
MS_EXCEPTION_IF_NULL
(
tensor
);
MS_EXCEPTION_IF_NULL
(
tensor
);
return
tensor
;
return
tensor
;
}
}
void
TensorValueToTensor
(
const
ValuePtr
&
value
,
std
::
vector
<
tensor
::
TensorPtr
>
*
tensors
)
{
MS_EXCEPTION_IF_NULL
(
value
);
MS_EXCEPTION_IF_NULL
(
tensors
);
if
(
value
->
isa
<
ValueTuple
>
())
{
auto
value_tuple
=
value
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_tuple
);
for
(
size_t
i
=
0
;
i
<
value_tuple
->
size
();
++
i
)
{
ValuePtr
element
=
value_tuple
->
value
()[
i
];
if
(
element
->
isa
<
tensor
::
Tensor
>
())
{
auto
tensor
=
element
->
cast
<
tensor
::
TensorPtr
>
();
MS_EXCEPTION_IF_NULL
(
tensor
);
tensors
->
push_back
(
tensor
);
}
}
}
else
if
(
value
->
isa
<
tensor
::
Tensor
>
())
{
tensor
::
TensorPtr
tensor
=
value
->
cast
<
tensor
::
TensorPtr
>
();
MS_EXCEPTION_IF_NULL
(
tensor
);
tensors
->
push_back
(
tensor
);
}
}
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/utils/convert_utils.h
浏览文件 @
2d973c95
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
#include <memory>
#include <memory>
#include <utility>
#include <utility>
#include <stack>
#include <stack>
#include <vector>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
...
@@ -69,6 +70,8 @@ using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>;
...
@@ -69,6 +70,8 @@ using NodeMapEquiv = std::unordered_map<AnfNodePtr, AnfNodePtr>;
bool
Isomorphic
(
FuncGraphPtr
g1
,
FuncGraphPtr
g2
,
FuncGraphPairMapEquiv
*
equiv_func_graph
,
NodeMapEquiv
*
equiv_node
);
bool
Isomorphic
(
FuncGraphPtr
g1
,
FuncGraphPtr
g2
,
FuncGraphPairMapEquiv
*
equiv_func_graph
,
NodeMapEquiv
*
equiv_node
);
tensor
::
TensorPtr
ScalarToTensor
(
const
ScalarPtr
&
scalar
);
tensor
::
TensorPtr
ScalarToTensor
(
const
ScalarPtr
&
scalar
);
void
TensorValueToTensor
(
const
ValuePtr
&
value
,
std
::
vector
<
tensor
::
TensorPtr
>
*
tensors
);
}
// namespace mindspore
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
#endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
mindspore/core/ir/anf.h
浏览文件 @
2d973c95
...
@@ -50,8 +50,13 @@ using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
...
@@ -50,8 +50,13 @@ using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using
AbstractBasePtr
=
std
::
shared_ptr
<
abstract
::
AbstractBase
>
;
using
AbstractBasePtr
=
std
::
shared_ptr
<
abstract
::
AbstractBase
>
;
using
AbstractBasePtrList
=
std
::
vector
<
AbstractBasePtr
>
;
using
AbstractBasePtrList
=
std
::
vector
<
AbstractBasePtr
>
;
class
Value
;
using
ValuePtr
=
std
::
shared_ptr
<
Value
>
;
using
ValuePtrList
=
std
::
vector
<
ValuePtr
>
;
class
ValueNode
;
class
ValueNode
;
using
ValueNodePtr
=
std
::
shared_ptr
<
ValueNode
>
;
using
ValueNodePtr
=
std
::
shared_ptr
<
ValueNode
>
;
class
CNode
;
class
CNode
;
using
CNodePtr
=
std
::
shared_ptr
<
CNode
>
;
using
CNodePtr
=
std
::
shared_ptr
<
CNode
>
;
...
@@ -225,6 +230,9 @@ class CNode : public AnfNode {
...
@@ -225,6 +230,9 @@ class CNode : public AnfNode {
void
set_input
(
size_t
i
,
const
AnfNodePtr
&
input
);
void
set_input
(
size_t
i
,
const
AnfNodePtr
&
input
);
void
set_inputs
(
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
inputs_
=
inputs
;
}
void
set_inputs
(
const
std
::
vector
<
AnfNodePtr
>
&
inputs
)
{
inputs_
=
inputs
;
}
void
set_forward
(
const
ValuePtr
&
forward
)
{
forward_
=
forward
;
}
const
ValuePtr
&
forward
()
const
{
return
forward_
;
}
bool
stop_gradient
()
const
{
return
stop_gradient_
;
}
bool
stop_gradient
()
const
{
return
stop_gradient_
;
}
void
set_stop_gradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
}
void
set_stop_gradient
(
bool
stop_gradient
)
{
stop_gradient_
=
stop_gradient
;
}
...
@@ -243,6 +251,7 @@ class CNode : public AnfNode {
...
@@ -243,6 +251,7 @@ class CNode : public AnfNode {
VarPtr
func_graph_as_var_
;
VarPtr
func_graph_as_var_
;
bool
stop_gradient_
;
bool
stop_gradient_
;
bool
in_forward_flag_
=
false
;
bool
in_forward_flag_
=
false
;
ValuePtr
forward_
=
nullptr
;
};
};
// ANode represents the atomic node. It's derived Parameter and ValueNode.
// ANode represents the atomic node. It's derived Parameter and ValueNode.
...
@@ -321,8 +330,6 @@ class Value : public Base {
...
@@ -321,8 +330,6 @@ class Value : public Base {
protected:
protected:
TypePtr
type_
{
nullptr
};
TypePtr
type_
{
nullptr
};
};
};
using
ValuePtr
=
std
::
shared_ptr
<
Value
>
;
using
ValuePtrList
=
std
::
vector
<
ValuePtr
>
;
// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
// does not belong to any particular function graph.
// does not belong to any particular function graph.
...
@@ -333,9 +340,13 @@ class ValueNode : public ANode {
...
@@ -333,9 +340,13 @@ class ValueNode : public ANode {
MS_DECLARE_PARENT
(
ValueNode
,
ANode
);
MS_DECLARE_PARENT
(
ValueNode
,
ANode
);
void
accept
(
AnfIrVisitor
*
v
)
override
;
void
accept
(
AnfIrVisitor
*
v
)
override
;
void
set_value
(
const
ValuePtr
&
value
)
{
value_
=
value
;
}
const
ValuePtr
&
value
()
const
{
return
value_
;
}
const
ValuePtr
&
value
()
const
{
return
value_
;
}
std
::
string
fullname_with_scope
()
override
;
std
::
string
fullname_with_scope
()
override
;
void
set_has_new_value
(
bool
flag
)
{
has_new_value_
=
flag
;
}
bool
has_new_value
()
const
{
return
has_new_value_
;
}
std
::
string
ToString
()
const
override
;
std
::
string
ToString
()
const
override
;
std
::
string
DebugString
(
int
recursive_level
=
1
)
const
override
;
std
::
string
DebugString
(
int
recursive_level
=
1
)
const
override
;
std
::
string
DebugString
(
bool
recursive
)
const
override
{
return
DebugString
(
recursive
?
1
:
0
);
}
std
::
string
DebugString
(
bool
recursive
)
const
override
{
return
DebugString
(
recursive
?
1
:
0
);
}
...
@@ -355,6 +366,7 @@ class ValueNode : public ANode {
...
@@ -355,6 +366,7 @@ class ValueNode : public ANode {
private:
private:
ValuePtr
value_
;
ValuePtr
value_
;
bool
has_new_value_
=
false
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
mindspore/core/ir/func_graph_cloner.cc
浏览文件 @
2d973c95
...
@@ -88,6 +88,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
...
@@ -88,6 +88,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
CNodePtr
new_node
=
std
::
make_shared
<
CNode
>
(
AnfNodePtrList
{},
target
);
CNodePtr
new_node
=
std
::
make_shared
<
CNode
>
(
AnfNodePtrList
{},
target
);
auto
old_node
=
node
->
cast
<
CNodePtr
>
();
auto
old_node
=
node
->
cast
<
CNodePtr
>
();
new_node
->
set_abstract
(
old_node
->
abstract
());
new_node
->
set_abstract
(
old_node
->
abstract
());
new_node
->
set_forward
(
old_node
->
forward
());
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
new_node
->
set_scope
(
scope
);
new_node
->
set_scope
(
scope
);
new_node
->
set_kernel_info
(
old_node
->
kernel_info_ptr
());
new_node
->
set_kernel_info
(
old_node
->
kernel_info_ptr
());
...
@@ -103,6 +104,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
...
@@ -103,6 +104,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node) {
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
new_const
->
set_scope
(
scope
);
new_const
->
set_scope
(
scope
);
new_const
->
set_abstract
(
node
->
abstract
());
new_const
->
set_abstract
(
node
->
abstract
());
new_const
->
set_has_new_value
(
node
->
cast
<
ValueNodePtr
>
()
->
has_new_value
());
repl_node_
[
node
]
=
new_const
;
repl_node_
[
node
]
=
new_const
;
TraceManager
::
EndTrace
();
TraceManager
::
EndTrace
();
}
}
...
@@ -115,6 +117,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target)
...
@@ -115,6 +117,7 @@ void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target)
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
ScopePtr
scope
=
(
node
->
scope
()
!=
kDefaultScope
)
?
node
->
scope
()
:
this
->
scope
();
new_const
->
set_scope
(
scope
);
new_const
->
set_scope
(
scope
);
new_const
->
set_abstract
(
node
->
abstract
());
new_const
->
set_abstract
(
node
->
abstract
());
new_const
->
set_has_new_value
(
node
->
cast
<
ValueNodePtr
>
()
->
has_new_value
());
repl_node_
[
node
]
=
new_const
;
repl_node_
[
node
]
=
new_const
;
TraceManager
::
EndTrace
();
TraceManager
::
EndTrace
();
}
}
...
...
tests/ut/python/pynative_mode/test_high_order_grad.py
浏览文件 @
2d973c95
...
@@ -19,7 +19,7 @@ from mindspore.ops.composite import grad, grad_all, grad_all_with_sens
...
@@ -19,7 +19,7 @@ from mindspore.ops.composite import grad, grad_all, grad_all_with_sens
def
setup_module
(
module
):
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
check_bprop
=
False
)
def
single
(
x
):
def
single
(
x
):
...
...
tests/vm_impl/vm_me.py
浏览文件 @
2d973c95
...
@@ -554,9 +554,7 @@ def softmax_cross_entropy_with_logits(logits, labels):
...
@@ -554,9 +554,7 @@ def softmax_cross_entropy_with_logits(logits, labels):
sample_num
=
labels
.
shape
[
0
]
sample_num
=
labels
.
shape
[
0
]
prob
=
softmax
(
logits
)
prob
=
softmax
(
logits
)
log_likelihood
=
-
np
.
log
(
prob
[
range
(
sample_num
)])
*
labels
log_likelihood
=
-
np
.
log
(
prob
[
range
(
sample_num
)])
*
labels
# loss = np.sum(log_likelihood)
loss
=
np
.
sum
(
log_likelihood
)
loss
=
log_likelihood
dx
=
prob
.
copy
()
dx
=
prob
.
copy
()
dx
[
range
(
sample_num
)]
-=
labels
dx
[
range
(
sample_num
)]
-=
labels
return
loss
,
dx
return
loss
,
dx
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录