Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
02d6e3a4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02d6e3a4
编写于
8月 08, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
dc961e46
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
58 addition
and
24 deletion
+58
-24
mindspore/_extends/parse/__init__.py
mindspore/_extends/parse/__init__.py
+2
-2
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+6
-0
mindspore/ccsrc/frontend/operator/composite/composite.cc
mindspore/ccsrc/frontend/operator/composite/composite.cc
+3
-3
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
+1
-0
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+3
-3
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+25
-10
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+2
-1
mindspore/ccsrc/utils/primitive_py.cc
mindspore/ccsrc/utils/primitive_py.cc
+9
-2
mindspore/common/tensor.py
mindspore/common/tensor.py
+6
-2
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+1
-1
未找到文件。
mindspore/_extends/parse/__init__.py
浏览文件 @
02d6e3a4
...
...
@@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_dataclass_attributes
,
get_dataclass_methods
,
get_obj_id
,
get_module_namespace
,
get_obj_type
,
get_object_key
,
get_parse_method_of_class
,
get_scope_name
,
is_class_member
,
parse_cb
,
resolve_symbol
)
is_class_member
,
parse_cb
,
resolve_symbol
,
convert_to_ms_tensor
)
from
.serialize
import
*
__all__
=
[
'parse_cb'
,
'get_parse_method_of_class'
,
'get_bprop_method_of_class'
,
'resolve_symbol'
,
...
...
@@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class',
'get_obj_type'
,
'get_obj_id'
,
'create_obj_instance'
,
'get_module_namespace'
,
'get_class_member_namespace_symbol'
,
'get_obj_id'
,
'Parser'
,
'get_dataclass_attributes'
,
'get_dataclass_methods'
,
'dump_obj'
,
'load_obj'
,
'get_dataclass_methods'
,
'get_scope_name'
,
'create_slice_obj'
]
'create_slice_obj'
,
'convert_to_ms_tensor'
]
mindspore/_extends/parse/parser.py
浏览文件 @
02d6e3a4
...
...
@@ -25,6 +25,7 @@ from dataclasses import is_dataclass
import
asttokens
import
mindspore.nn
as
nn
from
mindspore
import
log
as
logger
from
mindspore
import
Tensor
as
MsTensor
from
mindspore
import
ops
from
mindspore.common.dtype
import
pytype_to_dtype
from
mindspore.common.api
import
_MindSporeFunction
...
...
@@ -316,6 +317,11 @@ def get_dataclass_methods(cls):
return
methods
def
convert_to_ms_tensor
(
data
):
"""Convert C++ tensor to mindspore tensor."""
return
MsTensor
(
data
)
class
Parser
:
"""
Parser python code to ast tree.
...
...
mindspore/ccsrc/frontend/operator/composite/composite.cc
浏览文件 @
02d6e3a4
...
...
@@ -929,7 +929,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*
step_value
=
CheckSliceMember
(
slice
->
step
(),
step_default
,
step_name
);
if
(
*
step_value
==
0
)
{
MS_
LOG
(
EXCEPTION
)
<<
"TupleSlice require the step value could not be 0, but got 0."
;
MS_
EXCEPTION
(
ValueError
)
<<
"TupleSlice require the step value could not be 0, but got 0."
;
}
if
(
*
step_value
<
0
)
{
...
...
@@ -941,8 +941,8 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
*
stop_index
=
CheckSliceMember
(
slice
->
stop
(),
stop_default
,
stop_name
);
if
(
!
CheckIndexInRange
(
*
start_index
,
-
tuple_size
,
tuple_size
-
1
)
||
!
CheckIndexInRange
(
*
stop_index
,
-
tuple_size
-
1
,
tuple_size
))
{
MS_
LOG
(
EXCEPTION
)
<<
"TupleSlice the start index "
<<
*
start_index
<<
" or end end index "
<<
*
stop_index
<<
" out of range, tuple size "
<<
tuple_size
<<
"."
;
MS_
EXCEPTION
(
ValueError
)
<<
"TupleSlice the start index "
<<
*
start_index
<<
" or end end index "
<<
*
stop_index
<<
" out of range, tuple size "
<<
tuple_size
<<
"."
;
}
*
start_index
=
GetPositiveIndex
(
*
start_index
,
tuple_size
);
...
...
mindspore/ccsrc/pipeline/jit/parse/parse_base.h
浏览文件 @
02d6e3a4
...
...
@@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace";
const
char
PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL
[]
=
"get_class_member_namespace_symbol"
;
const
char
PYTHON_MOD_GET_PARSE_METHOD
[]
=
"get_parse_method_of_class"
;
const
char
PYTHON_MOD_GET_BPROP_METHOD
[]
=
"get_bprop_method_of_class"
;
const
char
PYTHON_MOD_CONVERT_TO_MS_TENSOR
[]
=
"convert_to_ms_tensor"
;
const
char
PYTHON_PARSE_GET_ARGS
[]
=
"get_args"
;
const
char
PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES
[]
=
"get_args_default_values"
;
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
02d6e3a4
...
...
@@ -226,11 +226,11 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
for
(
size_t
index
=
0
;
index
<
specialize_args_before_unpack
.
size
();
index
++
)
{
MS_EXCEPTION_IF_NULL
(
specialize_args_before_unpack
[
index
]);
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractTuple
>
())
{
AbstractTuplePtr
arg_tuple
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractTuplePtr
>
();
auto
arg_tuple
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractTuplePtr
>
();
std
::
transform
(
arg_tuple
->
elements
().
begin
(),
arg_tuple
->
elements
().
end
(),
std
::
back_inserter
(
graph_specialize_args
),
[](
AbstractBasePtr
abs
)
{
return
abs
;
});
}
else
if
(
specialize_args_before_unpack
[
index
]
->
isa
<
AbstractDictionary
>
())
{
AbstractDictionaryPtr
arg_dict
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractDictionaryPtr
>
();
auto
arg_dict
=
specialize_args_before_unpack
[
index
]
->
cast
<
AbstractDictionaryPtr
>
();
auto
dict_elems
=
arg_dict
->
elements
();
(
void
)
std
::
transform
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
std
::
back_inserter
(
graph_specialize_args
),
...
...
@@ -353,7 +353,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
}
auto
out_node
=
out_conf
->
node
()
->
cast
<
CNodePtr
>
();
const
auto
&
out_node_inputs
=
out_node
->
inputs
();
if
(
out_node
->
inputs
().
size
()
==
0
||
(
out_node_inputs
.
size
()
-
1
)
!=
args_conf_list
.
size
())
{
if
(
out_node
->
inputs
().
empty
()
||
(
out_node_inputs
.
size
()
-
1
)
!=
args_conf_list
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"MixedPrecisionCast"
<<
" args size should equal to inputs size minus 1, but args size "
<<
args_conf_list
.
size
()
<<
", inputs size "
<<
out_node_inputs
.
size
();
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
02d6e3a4
...
...
@@ -115,12 +115,12 @@ inline ValuePtr PyAttrValue(const py::object &obj) {
static
std
::
string
GetId
(
const
py
::
object
&
obj
)
{
py
::
object
to_process
=
obj
;
std
::
string
prefix
=
""
;
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
))
{
if
(
py
::
isinstance
<
py
::
tuple
>
(
to_process
)
||
py
::
isinstance
<
py
::
list
>
(
to_process
)
)
{
auto
p_list
=
py
::
cast
<
py
::
tuple
>
(
to_process
);
if
(
p_list
.
size
()
==
0
)
{
if
(
p_list
.
empty
()
)
{
return
"empty"
;
}
prefix
=
"tuple:
"
;
prefix
=
py
::
isinstance
<
py
::
tuple
>
(
to_process
)
?
"tuple:"
:
"list
"
;
std
::
string
key
=
""
;
for
(
size_t
i
=
0
;
i
<
p_list
.
size
();
++
i
)
{
key
+=
std
::
string
(
py
::
str
(
GetId
(
p_list
[
i
])))
+
":"
;
...
...
@@ -738,6 +738,21 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
return
node
;
}
std
::
string
PynativeExecutor
::
GetCellId
(
const
py
::
object
&
cell
,
const
py
::
args
&
args
)
{
auto
cell_id
=
GetId
(
cell
);
for
(
size_t
i
=
0
;
i
<
args
.
size
();
i
++
)
{
std
::
string
arg_id
=
GetId
(
args
[
i
]);
if
(
node_abs_map_
.
find
(
arg_id
)
!=
node_abs_map_
.
end
())
{
cell_id
+=
node_abs_map_
[
arg_id
]
->
ToString
();
}
else
{
AbstractBasePtr
abs
=
abstract
::
FromValueInside
(
PyAttrValue
(
args
[
i
]),
true
);
cell_id
+=
abs
->
ToString
();
node_abs_map_
[
arg_id
]
=
abs
;
}
}
return
cell_id
;
}
py
::
tuple
PynativeExecutor
::
RunOpInner
(
const
OpExecInfoPtr
&
op_exec_info
)
{
MS_LOG
(
INFO
)
<<
"RunOp start, op name is: "
<<
op_exec_info
->
op_name
;
mindspore
::
parse
::
python_adapter
::
set_python_env_flag
(
true
);
...
...
@@ -785,8 +800,8 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
}
auto
cnode
=
PynativeExecutor
::
GetInstance
()
->
MakeCNode
(
op_exec_info
,
&
op_masks
,
&
args_spec_list
);
bool
is_find
=
false
;
if
(
prim_abs_list
.
find
(
prim
->
id
())
!=
prim_abs_list
.
end
())
{
auto
abs_list
=
prim_abs_list
[
prim
->
id
()];
if
(
prim_abs_list
_
.
find
(
prim
->
id
())
!=
prim_abs_list_
.
end
())
{
auto
abs_list
=
prim_abs_list
_
[
prim
->
id
()];
MS_LOG
(
DEBUG
)
<<
"match prim input args "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
if
(
abs_list
.
find
(
args_spec_list
)
!=
abs_list
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"match prim ok"
<<
op_exec_info
->
op_name
;
...
...
@@ -827,7 +842,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
if
(
!
is_find
)
{
// const_value need infer every step
auto
&
out
=
prim_abs_list
[
prim
->
id
()];
auto
&
out
=
prim_abs_list
_
[
prim
->
id
()];
out
[
args_spec_list
].
abs
=
op_exec_info
->
abstract
;
out
[
args_spec_list
].
attrs
=
prim
->
evaluate_added_attrs
();
MS_LOG
(
DEBUG
)
<<
"set prim "
<<
op_exec_info
->
op_name
<<
mindspore
::
ToString
(
args_spec_list
);
...
...
@@ -890,7 +905,7 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); }
PynativeExecutor
::
PynativeExecutor
()
{
grad_flag_
=
false
;
}
void
PynativeExecutor
::
NewGraphInner
(
const
py
::
object
&
cell
,
const
py
::
args
&
args
)
{
auto
cell_id
=
Get
Id
(
cell
);
auto
cell_id
=
Get
CellId
(
cell
,
args
);
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
if
(
cell_resource_map_
.
find
(
cell_id
)
!=
cell_resource_map_
.
end
())
{
resource_
=
cell_resource_map_
[
cell_id
];
...
...
@@ -1016,7 +1031,7 @@ void PynativeExecutor::Popp() {
}
void
PynativeExecutor
::
EndGraphInner
(
const
py
::
object
&
cell
,
const
py
::
object
&
out
,
const
py
::
args
&
args
)
{
auto
cell_id
=
Get
Id
(
cell
);
auto
cell_id
=
Get
CellId
(
cell
,
args
);
if
(
cell_graph_map_
.
count
(
cell_id
)
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
"Endgraph already compiled"
;
return
;
...
...
@@ -1078,7 +1093,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje
inputs
.
push_back
(
input
);
}
auto
out_cnode
=
curr_g_
->
NewCNode
(
inputs
);
set_pyobj
(
curr_g_
,
Get
Id
(
cell
));
set_pyobj
(
curr_g_
,
Get
CellId
(
cell
,
args
));
if
(
py
::
isinstance
<
py
::
tuple
>
(
out
))
{
auto
out_list
=
py
::
cast
<
py
::
tuple
>
(
out
);
auto
out_size
=
static_cast
<
int
>
(
out_list
.
size
());
...
...
@@ -1169,7 +1184,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
MS_LOG
(
INFO
)
<<
"GradNet start"
<<
args
.
size
();
std
::
size_t
size
=
args
.
size
();
auto
cell_id
=
GetId
(
cell
);
std
::
string
cell_id
=
GetCellId
(
cell
,
args
);
if
(
graph_map_
.
count
(
cell_id
)
!=
0
)
{
MS_LOG
(
DEBUG
)
<<
"GradNet already compiled"
;
return
;
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.h
浏览文件 @
02d6e3a4
...
...
@@ -92,6 +92,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void
set_grad_flag
(
bool
flag
)
{
grad_flag_
=
flag
;
}
AnfNodePtr
GetInput
(
const
py
::
object
&
obj
,
bool
op_mask
);
AnfNodePtr
GetObjNode
(
const
py
::
object
&
obj
);
std
::
string
GetCellId
(
const
py
::
object
&
obj
,
const
py
::
args
&
args
);
FuncGraphPtr
curr_g
()
{
return
curr_g_
;
}
void
set_pyobj
(
FuncGraphPtr
g
,
const
std
::
string
obj
)
{
graph_info_map_
[
g
].
objects
.
push_back
(
obj
);
}
void
set_obj_node_map
(
FuncGraphPtr
g
,
const
std
::
string
obj
,
AnfNodePtr
node
)
{
...
...
@@ -141,7 +142,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
FuncGraphPtr
top_g_
;
FuncGraphPtr
df_builder_
;
FuncGraphPtr
curr_g_
;
std
::
unordered_map
<
std
::
string
,
AbstractListMap
>
prim_abs_list
;
std
::
unordered_map
<
std
::
string
,
AbstractListMap
>
prim_abs_list
_
;
};
using
PynativeExecutorPtr
=
std
::
shared_ptr
<
PynativeExecutor
>
;
...
...
mindspore/ccsrc/utils/primitive_py.cc
浏览文件 @
02d6e3a4
...
...
@@ -78,12 +78,19 @@ py::function PrimitivePy::GetBpropFunction() {
}
BaseRef
PrimitivePy
::
RunHookFunction
(
const
VectorRef
&
args
)
const
{
auto
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
tuple
py_args
=
ConvertDatatoPyTuple
(
args
);
py
::
object
obj
;
bool
is_bprop
=
this
->
HasAttr
(
kBpropAttrName
);
if
(
is_bprop
)
{
SyncData
(
py_args
);
obj
=
hook_
(
*
py_args
);
py
::
tuple
convert_args
(
py_args
.
size
());
for
(
size_t
i
=
0
;
i
<
py_args
.
size
();
i
++
)
{
convert_args
[
i
]
=
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
i
])
?
parse
::
python_adapter
::
CallPyFn
(
parse
::
PYTHON_MOD_PARSE_MODULE
,
parse
::
PYTHON_MOD_CONVERT_TO_MS_TENSOR
,
py_args
[
i
])
:
py_args
[
i
];
}
obj
=
hook_
(
*
convert_args
);
return
std
::
make_shared
<
PyObjectRef
>
(
obj
);
}
SyncData
(
py_args
[
2
]);
...
...
mindspore/common/tensor.py
浏览文件 @
02d6e3a4
...
...
@@ -210,12 +210,12 @@ class Tensor(Tensor_):
@
property
def
shape
(
self
):
"""The shape of tensor."""
"""The shape of tensor
is a tuple
."""
return
self
.
_shape
@
property
def
dtype
(
self
):
"""The dtype of tensor."""
"""The dtype of tensor
is a mindspore type
."""
return
self
.
_dtype
@
property
...
...
@@ -248,6 +248,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
"""
if
axis
is
None
:
axis
=
()
return
tensor_operator_registry
.
get
(
'all'
)(
keep_dims
)(
self
,
axis
)
def
any
(
self
,
axis
=
(),
keep_dims
=
False
):
...
...
@@ -264,6 +266,8 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
"""
if
axis
is
None
:
axis
=
()
return
tensor_operator_registry
.
get
(
'any'
)(
keep_dims
)(
self
,
axis
)
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
02d6e3a4
...
...
@@ -693,7 +693,7 @@ def get_bprop_unsorted_segment_min(self):
select
=
P
.
Select
()
def
bprop
(
x
,
segment_ids
,
num_segments
,
out
,
dout
):
gathered_outputs
,
zero_clipped_indices
,
is_positive
=
_GatherDropNegatives
(
out
,
segment_ids
)
gathered_outputs
,
zero_clipped_indices
,
is_positive
=
_GatherDropNegatives
(
out
,
segment_ids
,
None
,
None
)
is_selected
=
equal
(
x
,
gathered_outputs
)
is_selected
=
logical_and
(
is_selected
,
is_positive
)
num_selected
=
unsorted_segment_sum
(
cast
(
is_selected
,
get_dtype
(
dout
)),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录