Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
6af286b3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6af286b3
编写于
8月 27, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug the const input is broadened in PyNative mode
上级
e6ffdeeb
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
36 addition
and
27 deletion
+36
-27
mindspore/_extends/parse/parser.py
mindspore/_extends/parse/parser.py
+1
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+9
-5
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+12
-13
mindspore/core/ir/primitive.cc
mindspore/core/ir/primitive.cc
+2
-2
mindspore/core/ir/primitive.h
mindspore/core/ir/primitive.h
+8
-3
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-1
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+1
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+1
-0
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+1
-1
未找到文件。
mindspore/_extends/parse/parser.py
浏览文件 @
6af286b3
...
...
@@ -147,7 +147,7 @@ def resolve_symbol(namespace, symbol):
resolve_
=
namespace
[
symbol
]
# list and dict is not hashable ,it can not be key for the map, just return the result
if
isinstance
(
resolve_
,
(
list
,
dict
)):
if
isinstance
(
resolve_
,
(
tuple
,
list
,
dict
)):
return
resolve_
# dataclass may not be hashable
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
6af286b3
...
...
@@ -645,6 +645,9 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
inputs
.
push_back
(
NewValueNode
(
prim
));
size_t
size
=
op_exec_info
->
op_inputs
.
size
();
auto
const_input_index
=
prim
->
get_const_input_indexes
();
bool
have_const_input
=
!
const_input_index
.
empty
();
bool
is_const_prim
=
prim
->
is_const_prim
();
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
auto
obj
=
op_exec_info
->
op_inputs
[
i
];
bool
op_mask
=
false
;
...
...
@@ -672,12 +675,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
abs
=
node
->
abstract
();
}
MS_LOG
(
DEBUG
)
<<
prim
->
ToString
()
<<
" abs is nullptr "
<<
(
abs
==
nullptr
)
<<
" is_const_value "
<<
prim
->
is_const_value
();
if
(
abs
==
nullptr
||
prim
->
is_const_value
())
{
<<
prim
->
is_const_prim
();
bool
is_const_input
=
have_const_input
&&
std
::
count
(
const_input_index
.
begin
(),
const_input_index
.
end
(),
i
);
if
(
abs
==
nullptr
||
is_const_prim
||
is_const_input
)
{
MS_LOG
(
DEBUG
)
<<
"MakeCnode get node no in map"
<<
id
;
ValuePtr
input_value
=
PyAttrValue
(
obj
);
abs
=
input_value
->
ToAbstract
();
if
(
!
prim
->
is_const_value
()
)
{
if
(
!
is_const_prim
&&
!
is_const_input
)
{
auto
config
=
abstract
::
AbstractBase
::
kBroadenTensorOnly
;
abs
=
abs
->
Broaden
(
config
);
MS_LOG
(
DEBUG
)
<<
"broaden for "
<<
prim
->
ToString
()
<<
" "
<<
config
;
...
...
@@ -888,7 +892,7 @@ py::tuple PynativeExecutor::RunOpInner(const py::args &args) {
value_ret
[
0
]
=
output
[
"value"
];
return
value_ret
;
}
if
(
op_exec_info
->
py_primitive
->
is_const_
value
())
{
if
(
op_exec_info
->
py_primitive
->
is_const_
prim
())
{
py
::
tuple
value_ret
(
1
);
value_ret
[
0
]
=
""
;
return
value_ret
;
...
...
@@ -1041,7 +1045,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) {
auto
tuple
=
obj
.
cast
<
py
::
tuple
>
();
// cell((1,2)): support not mix (scalar, tensor)
if
(
tuple
.
size
()
>
0
&&
!
py
::
isinstance
<
tensor
::
Tensor
>
(
tuple
[
0
]))
{
if
(
!
tuple
.
empty
()
&&
!
py
::
isinstance
<
tensor
::
Tensor
>
(
tuple
[
0
]))
{
return
MakeValueNode
(
obj
,
obj_id
);
}
...
...
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
浏览文件 @
6af286b3
...
...
@@ -98,22 +98,22 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
<<
", and the value is "
<<
py
::
cast
<
py
::
str
>
(
grads
[
i
])
<<
"."
;
}
py
::
tuple
grad_shape
=
grads
[
i
].
attr
(
"sha
pe"
);
py
::
object
arg_dtype
=
py_args
[
i
].
attr
(
"dty
pe"
);
py
::
object
grad_dtype
=
grads
[
i
].
attr
(
"dtype"
);
py
::
tuple
arg_shape
=
py_args
[
i
].
attr
(
"shape"
);
py
::
object
arg_dtype
=
py_args
[
i
].
attr
(
"dtype"
);
py
::
tuple
grad_shape
=
grads
[
i
].
attr
(
"shape"
);
if
(
!
grad_dtype
.
equal
(
arg_dtype
))
{
MS_EXCEPTION
(
TypeError
)
<<
"When user defines the net bprop, the gradient of the "
<<
i
<<
"th arg should have the same dtype as the "
<<
i
<<
"th arg, but the "
<<
i
<<
"th arg dtype is: "
<<
py
::
cast
<
py
::
str
>
(
arg_dtype
)
<<
", the gradient dtype is: "
<<
py
::
cast
<
py
::
str
>
(
grad_dtype
)
<<
"."
;
}
if
(
!
grad_shape
.
equal
(
arg_shape
))
{
MS_EXCEPTION
(
ValueError
)
<<
"When user defines the net bprop, the gradient of the "
<<
i
<<
"th arg should have the same shape as the "
<<
i
<<
"th arg, but the "
<<
i
<<
"th arg shape is: "
<<
py
::
cast
<
py
::
str
>
(
arg_shape
)
<<
", the gradient shape is: "
<<
py
::
cast
<
py
::
str
>
(
grad_shape
)
<<
"."
;
}
if
(
!
grad_dtype
.
is
(
arg_dtype
))
{
MS_EXCEPTION
(
TypeError
)
<<
"When user defines the net bprop, the gradient of the "
<<
i
<<
"th arg should have the same dtype as the "
<<
i
<<
"th arg, but the "
<<
i
<<
"th arg dtype is: "
<<
py
::
cast
<
py
::
str
>
(
arg_dtype
)
<<
", the gradient dtype is: "
<<
py
::
cast
<
py
::
str
>
(
grad_dtype
)
<<
"."
;
}
}
}
}
...
...
@@ -239,10 +239,7 @@ py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
bool
PrimitivePy
::
HasComputeFunction
()
const
{
auto
func
=
GetComputeFunction
();
if
(
py
::
isinstance
<
py
::
none
>
(
func
))
{
return
false
;
}
return
true
;
return
!
py
::
isinstance
<
py
::
none
>
(
func
);
}
PrimitivePtr
PrimitivePy
::
Clone
()
{
...
...
@@ -272,7 +269,9 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.
def
(
"add_attr"
,
&
PrimitivePy
::
AddPyAttr
,
"add primitive attr"
)
.
def
(
"get_attr_dict"
,
&
PrimitivePy
::
GetAttrDict
,
"get primitive attr"
)
.
def
(
"set_prim_type"
,
&
PrimitivePy
::
set_prim_type
,
"Set primitive type."
)
.
def
(
"set_is_const_value"
,
&
PrimitivePy
::
set_is_const_value
,
"Set primitive is const value."
)
.
def
(
"set_const_prim"
,
&
PrimitivePy
::
set_const_prim
,
"Set primitive is const."
)
.
def
(
"set_const_input_indexes"
,
&
PrimitivePy
::
set_const_input_indexes
,
"Set primitive const input indexes."
)
.
def
(
"set_signatures"
,
&
PrimitivePy
::
set_signatures
,
"Set primitive inputs signature."
)
.
def
(
"register_hook"
,
&
PrimitivePy
::
set_hook
,
"Set primitive hook function."
)
.
def
(
"set_instance_name"
,
&
PrimitivePy
::
set_instance_name
,
"Set primitive instance name."
);
...
...
mindspore/core/ir/primitive.cc
浏览文件 @
6af286b3
...
...
@@ -32,7 +32,7 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType
has_signature_
(
false
),
prim_type_
(
prim_type
),
record_evaluate_add_attr_
(
false
),
is_const_
value
_
(
false
),
is_const_
prim
_
(
false
),
id_
(
MakeId
())
{}
Primitive
::
Primitive
(
const
Primitive
&
prim
)
...
...
@@ -43,7 +43,7 @@ Primitive::Primitive(const Primitive &prim)
has_signature_
(
prim
.
has_signature_
),
prim_type_
(
prim
.
prim_type_
),
record_evaluate_add_attr_
(
false
),
is_const_
value
_
(
false
),
is_const_
prim
_
(
false
),
id_
(
prim
.
id_
)
{}
abstract
::
AbstractBasePtr
Primitive
::
ToAbstract
()
{
...
...
mindspore/core/ir/primitive.h
浏览文件 @
6af286b3
...
...
@@ -109,8 +109,12 @@ class Primitive : public Named {
bool
is_base
()
const
{
return
is_base_
;
}
virtual
BaseRef
RunHookFunction
(
const
VectorRef
&
args
)
const
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
virtual
void
CopyHookFunction
(
const
PrimitivePtr
&
primitive
)
{
MS_LOG
(
EXCEPTION
)
<<
"call a empty function!"
;
}
void
set_is_const_value
(
bool
value
)
{
is_const_value_
=
value
;
}
bool
is_const_value
()
const
{
return
is_const_value_
;
}
void
set_const_prim
(
bool
is_const_prim
)
{
is_const_prim_
=
is_const_prim
;
}
bool
is_const_prim
()
const
{
return
is_const_prim_
;
}
void
set_const_input_indexes
(
const
std
::
vector
<
size_t
>
&
const_input_indexes
)
{
const_input_indexes_
=
const_input_indexes
;
}
std
::
vector
<
size_t
>
&
get_const_input_indexes
()
{
return
const_input_indexes_
;
}
std
::
string
id
()
const
{
return
id_
;
}
protected:
...
...
@@ -123,7 +127,8 @@ class Primitive : public Named {
bool
has_signature_
;
PrimType
prim_type_
;
bool
record_evaluate_add_attr_
;
bool
is_const_value_
;
bool
is_const_prim_
;
std
::
vector
<
size_t
>
const_input_indexes_
;
std
::
string
id_
{
""
};
};
...
...
mindspore/ops/functional.py
浏览文件 @
6af286b3
...
...
@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
cast
=
P
.
Cast
()
dtype
=
P
.
DType
()
isconstant
=
Primitive
(
'is_constant'
)
isconstant
.
set_
is_const_value
(
True
)
isconstant
.
set_
const_prim
(
True
)
issubclass_
=
P
.
IsSubClass
()
isinstance_
=
P
.
IsInstance
()
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
6af286b3
...
...
@@ -1089,7 +1089,7 @@ class InvertPermutation(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
):
"""init InvertPermutation"""
self
.
set_
is_const_value
(
True
)
self
.
set_
const_prim
(
True
)
def
__infer__
(
self
,
x
):
x_shp
=
x
[
'shape'
]
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
6af286b3
...
...
@@ -2866,6 +2866,7 @@ class MirrorPad(PrimitiveWithInfer):
"""Init Pad"""
validator
.
check_string
(
'mode'
,
mode
,
[
'REFLECT'
,
'SYMMETRIC'
],
self
.
name
)
self
.
mode
=
mode
self
.
set_const_input_indexes
([
1
])
def
__infer__
(
self
,
input_x
,
paddings
):
validator
.
check_subclass
(
"input_x"
,
input_x
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
...
...
mindspore/ops/primitive.py
浏览文件 @
6af286b3
...
...
@@ -390,7 +390,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def
__init__
(
self
):
op_name
=
name
if
name
else
fn
.
__name__
PrimitiveWithInfer
.
__init__
(
self
,
op_name
)
self
.
set_
is_const_value
(
True
)
self
.
set_
const_prim
(
True
)
def
infer_value
(
self
,
*
args
):
return
fn
(
*
args
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录