Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
02d6e3a4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02d6e3a4
编写于
8月 08, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
dc961e46
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
58 addition
and
24 deletion
+58
-24
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+6
-0
mindspore/ccsrc/frontend/operator/composite/composite.cc
mindspore/ccsrc/frontend/operator/composite/composite.cc
+3
-3
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+1
-0
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+3
-3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+25
-10
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+2
-1
mindspore/ccsrc/utils/primitive_py.cc
mindspore/ccsrc/utils/primitive_py.cc
+9
-2
mindspore/common/tensor.py
mindspore/common/tensor.py
+6
-2
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+1
-1
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
02d6e3a4
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
get_parse_method_of_class
,
get_scope_name
,
is_class_member
,
parse_cb
,
resolve_symbol
)
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
)
from
.serialize
import
*
from
.serialize
import
*
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
]
'create_slice_obj'
,
'convert_to_ms_tensor'
]
mindspore/_extends/parse/parser.py
浏览文件 @
02d6e3a4
...
@@ -25,6 +25,7 @@ from dataclasses import is_dataclass
...
@@ -25,6 +25,7 @@ from dataclasses import is_dataclass
import
asttokens
import
asttokens
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
mindspore
import
Tensor
as
MsTensor
from
mindspore
import
ops
from
mindspore
import
ops
from
mindspore.common.dtype
import
pytype_to_dtype
from
mindspore.common.dtype
import
pytype_to_dtype
from
mindspore.common.api
import
_MindSporeFunction
from
mindspore.common.api
import
_MindSporeFunction
...
@@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
...
@@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
return
methods
return
methods
def
convert_to_ms_tensor
(
data
):
"""Convert C++ tensor to mindspore tensor."""
return
MsTensor
(
data
)
class
Parser
:
class
Parser
:
"""
"""
Parser python code to ast tree.
Parser python code to ast tree.
...
...
mindspore/ccsrc/frontend/operator/composite/composite.cc
浏览文件 @
02d6e3a4
...
@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
...
@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*
step_value
=
CheckSliceMember
(
slice
->
step
(),
step_default
,
step_name
);
*
step_value
=
CheckSliceMember
(
slice
->
step
(),
step_default
,
step_name
);
if
(
*
step_value
==
0
)
{
if
(
*
step_value
==
0
)
{
MS_
LOG
(
EXCEPTION
)
<<
"TupleSlice require the step value could not be 0, but got 0."
;
MS_
EXCEPTION
(
ValueError
)
<<
"TupleSlice require the step value could not be 0, but got 0."
;
}
}
if
(
*
step_value
<
0
)
{
if
(
*
step_value
<
0
)
{
...
@@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
...
@@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*
stop_index
=
CheckSliceMember
(
slice
->
stop
(),
stop_default
,
stop_name
);
*
stop_index
=
CheckSliceMember
(
slice
->
stop
(),
stop_default
,
stop_name
);
if
(
!
CheckIndexInRange
(
*
start_index
,
-
tuple_size
,
tuple_size
-
1
)
||
if
(
!
CheckIndexInRange
(
*
start_index
,
-
tuple_size
,
tuple_size
-
1
)
||
!
CheckIndexInRange
(
*
stop_index
,
-
tuple_size
-
1
,
tuple_size
))
{
!
CheckIndexInRange
(
*
stop_index
,
-
tuple_size
-
1
,
tuple_size
))
{
MS_
LOG
(
EXCEPTION
)
<<
"TupleSlice the start index "
<<
*
start_index
<<
" or end end index "
<<
*
stop_index
MS_
EXCEPTION
(
ValueError
)
<<
"TupleSlice the start index "
<<
*
start_index
<<
" or end end index "
<<
*
stop_index
<<
" out of range, tuple size "
<<
tuple_size
<<
"."
;
<<
" out of range, tuple size "
<<
tuple_size
<<
"."
;
}
}
*
start_index
=
GetPositiveIndex
(
*
start_index
,
tuple_size
);
*
start_index
=
GetPositiveIndex
(
*
start_index
,
tuple_size
);
...
...
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
02d6e3a4
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_CONVERT_TO_MS_TENSOR
[]
=
"convert_to_ms_tensor"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
const
char
PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES
[]
=
"get_args_default_values"
;
const
char
PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES
[]
=
"get_args_default_values"
;
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
02d6e3a4
...
@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
...
@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
for
(
size_t
index
=
0
;
index
<
specialize_args_before_unpack
.
size
();
index
++
)
{
for
(
size_t
index
=
0
;
index
<
specialize_args_before_unpack
.
size
();
index
++
)
{
MS_EXCEPTION_IF_NULL
(
specialize_args_before_unpack
[
index
]);
MS_EXCEPTION_IF_NULL
(
specialize_args_before_unpack
[
index
]);
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractTuple
>
())
{
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractTuple
>
())
{
AbstractTuplePtr
arg_tuple
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractTuplePtr
>
();
auto
arg_tuple
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractTuplePtr
>
();
std
::
transform
(
arg_tuple
->
elements
().
begin
(),
arg_tuple
->
elements
().
end
(),
std
::
transform
(
arg_tuple
->
elements
().
begin
(),
arg_tuple
->
elements
().
end
(),
std
::
back_inserter
(
graph_specialize_args
),
[](
AbstractBasePtr
abs
)
{
return
abs
;
});
std
::
back_inserter
(
graph_specialize_args
),
[](
AbstractBasePtr
abs
)
{
return
abs
;
});
}
else
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractDictionary
>
())
{
}
else
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractDictionary
>
())
{
AbstractDictionaryPtr
arg_dict
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractDictionaryPtr
>
();
auto
arg_dict
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractDictionaryPtr
>
();
auto
dict_elems
=
arg_dict
->
elements
();
auto
dict_elems
=
arg_dict
->
elements
();
(
void
)
std
::
transform
(
(
void
)
std
::
transform
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
std
::
back_inserter
(
graph_specialize_args
),
dict_elems
.
begin
(),
dict_elems
.
end
(),
std
::
back_inserter
(
graph_specialize_args
),
...
@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
...
@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
}
}
auto
out_node
=
out_conf
->
node
()
->
cast
<
CNodePtr
>
();
auto
out_node
=
out_conf
->
node
()
->
cast
<
CNodePtr
>
();
const
auto
&
out_node_inputs
=
out_node
->
inputs
();
const
auto
&
out_node_inputs
=
out_node
->
inputs
();
if
(
out_node
->
inputs
().
size
()
==
0
||
(
out_node_inputs
.
size
()
-
1
)
!=
args_conf_list
.
size
())
{
if
(
out_node
->
inputs
().
empty
()
||
(
out_node_inputs
.
size
()
-
1
)
!=
args_conf_list
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"MixedPrecisionCast"
MS_LOG
(
EXCEPTION
)
<<
"MixedPrecisionCast"
<<
" args size should equal to inputs size minus 1, but args size "
<<
args_conf_list
.
size
()
<<
" args size should equal to inputs size minus 1, but args size "
<<
args_conf_list
.
size
()
<<
", inputs size "
<<
out_node_inputs
.
size
();
<<
", inputs size "
<<
out_node_inputs
.
size
();
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
02d6e3a4
...
@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
...
@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
static
std
::
string
GetId
(
const
py
::
object
&
obj
)
{
static
std
::
string
GetId
(
const
py
::
object
&
obj
)
{
py
::
object
to_process
=
obj
;
py
::
object
to_process
=
obj
;
std
::
string
prefix
=
""
;
std
::
string
prefix
=
""
;
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
))
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
)
||
py
::
isinstance
<
py
::
list
>
(
to_process
)
)
{
auto
p_list
=
py
::
cast
<
py
::
tuple
>
(
to_process
);
auto
p_list
=
py
::
cast
<
py
::
tuple
>
(
to_process
);
if
(
p_list
.
size
()
==
0
)
{
if
(
p_list
.
empty
()
)
{
return
"empty"
;
return
"empty"
;
}
}
prefix
=
"tuple:
"
;
prefix
=
py
::
isinstance
<
py
::
tuple
>
(
to_process
)
?
"tuple:"
:
"list
"
;
std
::
string
key
=
""
;
std
::
string
key
=
""
;
for
(
size_t
i
=
0
;
i
<
p_list
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
p_list
.
size
();
++
i
)
{
key
+=
std
::
string
(
py
::
str
(
GetId
(
p_list
[
i
])))
+
":"
;
key
+=
std
::
string
(
py
::
str
(
GetId
(
p_list
[
i
])))
+
":"
;
...
@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
...
@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return
node
;
return
node
;
}
}
std
::
string
PynativeExecutor
::
GetCellId
(
const
py
::
object
&
cell
,
const
py
::
args
&
args
)
{
auto
cell_id
=
GetId
(
cell
);
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
std
::
string
arg_id
=
GetId
(
args
[
i
]);
if
(
node_abs_map_
.
find
(
arg_id
)
!=
node_abs_map_
.
end
())
{
cell_id
+=
node_abs_map_
[
arg_id
]
->
ToString
();
}
else
{
AbstractBasePtr
abs
=
abstract
::
FromValueInside
(
PyAttrValue
(
args
[
i
]),
true
);
cell_id
+=
abs
->
ToString
();
node_abs_map_
[
arg_id
]
=
abs
;
}
}
return
cell_id
;
}
py
::
tuple
PynativeExecutor
::
RunOpInner
(
const
OpExecInfoPtr
&
op_exec_info
)
{
py
::
tuple
PynativeExecutor
::
RunOpInner
(
const
OpExecInfoPtr
&
op_exec_info
)
{
MS_LOG
(
INFO
)
<<
"RunOp start, op name is: "
<<
op_exec_info
->
op_name
;
MS_LOG
(
INFO
)
<<
"RunOp start, op name is: "
<<
op_exec_info
->
op_name
;
mindspore
::
parse
::
python_adapter
::
set_python_env_flag
(
true
);
mindspore
::
parse
::
python_adapter
::
set_python_env_flag
(
true
);
...
@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
...
@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
}
}
auto
cnode
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
&
op_masks
,
&
args_spec_list
);
auto
cnode
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
&
op_masks
,
&
args_spec_list
);
bool
is_find
=
false
;
bool
is_find
=
false
;
if
(
prim_abs_list
.
find
(
prim
->
id
())
!=
prim_abs_list
.
end
())
{
if
(
prim_abs_list
_
.
find
(
prim
->
id
())
!=
prim_abs_list_
.
end
())
{
auto
abs_list
=
prim_abs_list
[
prim
->
id
()];
auto
abs_list
=
prim_abs_list
_
[
prim
->
id
()];
MS_LOG
(
DEBUG
)
<<
"match prim input args "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"match prim input args "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
if
(
abs_list
.
find
(
args_spec_list
)
!=
abs_list
.
end
())
{
if
(
abs_list
.
find
(
args_spec_list
)
!=
abs_list
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"match prim ok"
<<
op_exec_info
->
op_name
;
MS_LOG
(
DEBUG
)
<<
"match prim ok"
<<
op_exec_info
->
op_name
;
...
@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
...
@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
if
(
!
is_find
)
{
if
(
!
is_find
)
{
// const_value need infer every step
// const_value need infer every step
auto
&
out
=
prim_abs_list
[
prim
->
id
()];
auto
&
out
=
prim_abs_list
_
[
prim
->
id
()];
out
[
args_spec_list
].
abs
=
op_exec_info
->
abstract
;
out
[
args_spec_list
].
abs
=
op_exec_info
->
abstract
;
out
[
args_spec_list
].
attrs
=
prim
->
evaluate_added_attrs
();
out
[
args_spec_list
].
attrs
=
prim
->
evaluate_added_attrs
();
MS_LOG
(
DEBUG
)
<<
"set prim "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
MS_LOG
(
DEBUG
)
<<
"set prim "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
...
@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
...
@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor
::
PynativeExecutor
()
{
grad_flag_
=
false
;
}
PynativeExecutor
::
PynativeExecutor
()
{
grad_flag_
=
false
;
}
void
PynativeExecutor
::
NewGraphInner
(
const
py
::
object
&
cell
,
const
py
::
args
&
args
)
{
void
PynativeExecutor
::
NewGraphInner
(
const
py
::
object
&
cell
,
const
py
::
args
&
args
)
{
auto
cell_id
=
Get
Id
(
cell
);
auto
cell_id
=
Get
CellId
(
cell
,
args
);
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
if
(
cell_resource_map_
.
find
(
cell_id
)
!=
cell_resource_map_
.
end
())
{
if
(
cell_resource_map_
.
find
(
cell_id
)
!=
cell_resource_map_
.
end
())
{
resource_
=
cell_resource_map_
[
cell_id
];
resource_
=
cell_resource_map_
[
cell_id
];
...
@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
...
@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
}
}
void
PynativeExecutor
::
EndGraphInner
(
const
py
::
object
&
cell
,
const
py
::
object
&
out
,
const
py
::
args
&
args
)
{
void
PynativeExecutor
::
EndGraphInner
(
const
py
::
object
&
cell
,
const
py
::
object
&
out
,
const
py
::
args
&
args
)
{
auto
cell_id
=
Get
Id
(
cell
);
auto
cell_id
=
Get
CellId
(
cell
,
args
);
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
"Endgraph already compiled"
;
MS_LOG
(
DEBUG
)
<<
"Endgraph already compiled"
;
return
;
return
;
...
@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
...
@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
inputs
.
push_back
(
input
);
inputs
.
push_back
(
input
);
}
}
auto
out_cnode
=
curr_g_
->
NewCNode
(
inputs
);
auto
out_cnode
=
curr_g_
->
NewCNode
(
inputs
);
set_pyobj
(
curr_g_
,
Get
Id
(
cell
));
set_pyobj
(
curr_g_
,
Get
CellId
(
cell
,
args
));
if
(
py
::
isinstance
<
py
::
tuple
>
(
out
))
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
out
))
{
auto
out_list
=
py
::
cast
<
py
::
tuple
>
(
out
);
auto
out_list
=
py
::
cast
<
py
::
tuple
>
(
out
);
auto
out_size
=
static_cast
<
int
>
(
out_list
.
size
());
auto
out_size
=
static_cast
<
int
>
(
out_list
.
size
());
...
@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
...
@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
MS_LOG
(
INFO
)
<<
"GradNet start"
<<
args
.
size
();
MS_LOG
(
INFO
)
<<
"GradNet start"
<<
args
.
size
();
std
::
size_t
size
=
args
.
size
();
std
::
size_t
size
=
args
.
size
();
auto
cell_id
=
GetId
(
cell
);
std
::
string
cell_id
=
GetCellId
(
cell
,
args
);
if
(
graph_map_
.
count
(
cell_id
)
!=
0
)
{
if
(
graph_map_
.
count
(
cell_id
)
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
"GradNet already compiled"
;
MS_LOG
(
DEBUG
)
<<
"GradNet already compiled"
;
return
;
return
;
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
浏览文件 @
02d6e3a4
...
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
...
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
set_grad_flag
(
bool
flag
)
{
grad_flag_
=
flag
;
}
void
set_grad_flag
(
bool
flag
)
{
grad_flag_
=
flag
;
}
AnfNodePtr
GetInput
(
const
py
::
object
&
obj
,
bool
op_mask
);
AnfNodePtr
GetInput
(
const
py
::
object
&
obj
,
bool
op_mask
);
AnfNodePtr
GetObjNode
(
const
py
::
object
&
obj
);
AnfNodePtr
GetObjNode
(
const
py
::
object
&
obj
);
std
::
string
GetCellId
(
const
py
::
object
&
obj
,
const
py
::
args
&
args
);
FuncGraphPtr
curr_g
()
{
return
curr_g_
;
}
FuncGraphPtr
curr_g
()
{
return
curr_g_
;
}
void
set_pyobj
(
FuncGraphPtr
g
,
const
std
::
string
obj
)
{
graph_info_map_
[
g
].
objects
.
push_back
(
obj
);
}
void
set_pyobj
(
FuncGraphPtr
g
,
const
std
::
string
obj
)
{
graph_info_map_
[
g
].
objects
.
push_back
(
obj
);
}
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
)
{
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
)
{
...
@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
...
@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr
top_g_
;
FuncGraphPtr
top_g_
;
FuncGraphPtr
df_builder_
;
FuncGraphPtr
df_builder_
;
FuncGraphPtr
curr_g_
;
FuncGraphPtr
curr_g_
;
std
::
unordered_map
<
std
::
string
,
AbstractListMap
>
prim_abs_list
;
std
::
unordered_map
<
std
::
string
,
AbstractListMap
>
prim_abs_list
_
;
};
};
using
PynativeExecutorPtr
=
std
::
shared_ptr
<
PynativeExecutor
>
;
using
PynativeExecutorPtr
=
std
::
shared_ptr
<
PynativeExecutor
>
;
...
...
mindspore/ccsrc/utils/primitive_py.cc
浏览文件 @
02d6e3a4
...
@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() {
...
@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() {
}
}
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
auto
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
tuple
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
object
obj
;
py
::
object
obj
;
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
if
(
is_bprop
)
{
if
(
is_bprop
)
{
SyncData
(
py_args
);
SyncData
(
py_args
);
obj
=
hook_
(
*
py_args
);
py
::
tuple
convert_args
(
py_args
.
size
());
for
(
size_t
i
=
0
;
i
<
py_args
.
size
();
i
++
)
{
convert_args
[
i
]
=
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
i
])
?
parse
::
python_adapter
::
CallPyFn
(
parse
::
PYTHON_MOD_PARSE_MODULE
,
parse
::
PYTHON_MOD_CONVERT_TO_MS_TENSOR
,
py_args
[
i
])
:
py_args
[
i
];
}
obj
=
hook_
(
*
convert_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
}
SyncData
(
py_args
[
2
]);
SyncData
(
py_args
[
2
]);
...
...
mindspore/common/tensor.py
浏览文件 @
02d6e3a4
...
@@ -210,12 +210,12 @@ class Tensor(Tensor_):
...
@@ -210,12 +210,12 @@ class Tensor(Tensor_):
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
"""The shape of tensor."""
"""The shape of tensor
is a tuple
."""
return
self
.
_shape
return
self
.
_shape
@
property
@
property
def
dtype
(
self
):
def
dtype
(
self
):
"""The dtype of tensor."""
"""The dtype of tensor
is a mindspore type
."""
return
self
.
_dtype
return
self
.
_dtype
@
property
@
property
...
@@ -248,6 +248,8 @@ class Tensor(Tensor_):
...
@@ -248,6 +248,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
Tensor, has the same data type as x.
"""
"""
if
axis
is
None
:
axis
=
()
return
tensor_operator_registry
.
get
(
'all'
)(
keep_dims
)(
self
,
axis
)
return
tensor_operator_registry
.
get
(
'all'
)(
keep_dims
)(
self
,
axis
)
def
any
(
self
,
axis
=
(),
keep_dims
=
False
):
def
any
(
self
,
axis
=
(),
keep_dims
=
False
):
...
@@ -264,6 +266,8 @@ class Tensor(Tensor_):
...
@@ -264,6 +266,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
Tensor, has the same data type as x.
"""
"""
if
axis
is
None
:
axis
=
()
return
tensor_operator_registry
.
get
(
'any'
)(
keep_dims
)(
self
,
axis
)
return
tensor_operator_registry
.
get
(
'any'
)(
keep_dims
)(
self
,
axis
)
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
02d6e3a4
...
@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self):
...
@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self):
select
=
P
.
Select
()
select
=
P
.
Select
()
def
bprop
(
x
,
segment_ids
,
num_segments
,
out
,
dout
):
def
bprop
(
x
,
segment_ids
,
num_segments
,
out
,
dout
):
gathered_outputs
,
zero_clipped_indices
,
is_positive
=
_GatherDropNegatives
(
out
,
segment_ids
)
gathered_outputs
,
zero_clipped_indices
,
is_positive
=
_GatherDropNegatives
(
out
,
segment_ids
,
None
,
None
)
is_selected
=
equal
(
x
,
gathered_outputs
)
is_selected
=
equal
(
x
,
gathered_outputs
)
is_selected
=
logical_and
(
is_selected
,
is_positive
)
is_selected
=
logical_and
(
is_selected
,
is_positive
)
num_selected
=
unsorted_segment_sum
(
cast
(
is_selected
,
get_dtype
(
dout
)),
num_selected
=
unsorted_segment_sum
(
cast
(
is_selected
,
get_dtype
(
dout
)),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录