Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
be3d79cb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be3d79cb
编写于
9月 02, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4204 add dynamic shape support for GatherV2 and others
Merge pull request !4204 from fary86/adapt_primitive_dynamic_shape
上级
cd5a1bf1
144a35b1
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
549 addition
and
102 deletion
+549
-102
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+114
-61
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+2
-0
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
...ore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
+5
-7
mindspore/ccsrc/pipeline/jit/validator.cc
mindspore/ccsrc/pipeline/jit/validator.cc
+4
-0
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+19
-2
mindspore/ccsrc/pybind_api/ir/primitive_py.h
mindspore/ccsrc/pybind_api/ir/primitive_py.h
+2
-0
mindspore/core/abstract/infer_functions.h
mindspore/core/abstract/infer_functions.h
+11
-0
mindspore/core/abstract/prim_arrays.cc
mindspore/core/abstract/prim_arrays.cc
+57
-0
mindspore/core/abstract/prim_maths.cc
mindspore/core/abstract/prim_maths.cc
+9
-0
mindspore/core/abstract/prim_nn.cc
mindspore/core/abstract/prim_nn.cc
+20
-0
mindspore/core/abstract/primitive_infer_map.cc
mindspore/core/abstract/primitive_infer_map.cc
+6
-0
mindspore/core/base/core_ops.h
mindspore/core/base/core_ops.h
+5
-0
mindspore/core/ir/primitive.h
mindspore/core/ir/primitive.h
+2
-1
mindspore/core/utils/flags.cc
mindspore/core/utils/flags.cc
+15
-0
mindspore/core/utils/flags.h
mindspore/core/utils/flags.h
+10
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+32
-10
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+3
-7
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+7
-11
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+92
-0
tests/ut/python/ir/test_row_tensor.py
tests/ut/python/ir/test_row_tensor.py
+23
-2
tests/ut/python/ops/test_dynamic_shape.py
tests/ut/python/ops/test_dynamic_shape.py
+109
-0
未找到文件。
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
be3d79cb
...
...
@@ -49,22 +49,6 @@ using mindspore::parse::PyObjectWrapper;
std
::
unordered_set
<
std
::
string
>
prims_to_skip_undetermined_infer
{
"make_tuple"
,
"make_list"
,
"switch"
,
"env_setitem"
,
"env_getitem"
};
EvalResultPtr
StandardPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
{
if
(
prims_to_skip_undetermined_infer
.
find
(
prim_
->
name
())
==
prims_to_skip_undetermined_infer
.
end
())
{
auto
ret_abstract
=
AbstractEval
(
args
);
if
(
ret_abstract
!=
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"StandardPrimEvaluator eval Undetermined"
;
return
ret_abstract
;
}
}
prim_
->
BeginRecordAddAttr
();
AbstractBasePtr
abs_base
=
eval_impl_
(
engine
,
prim_
,
args
);
prim_
->
EndRecordAddAttr
();
auto
added_attrs
=
prim_
->
evaluate_added_attrs
();
auto
infer_result
=
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
return
infer_result
;
}
EvalResultPtr
DoSignatureEvaluator
::
Run
(
AnalysisEnginePtr
engine
,
const
ConfigPtrList
&
args_conf_list
,
AnfNodeConfigPtr
out_conf
)
{
AbstractBasePtrList
args_spec_list
;
...
...
@@ -289,45 +273,45 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
py
::
dict
dic
;
if
(
abs_base
->
isa
<
AbstractTensor
>
())
{
auto
arg_tensor
=
dyn_cast
<
AbstractTensor
>
(
abs_base
);
dic
[
"shape"
]
=
arg_tensor
->
shape
()
->
shape
();
dic
[
ATTR_SHAPE
]
=
arg_tensor
->
shape
()
->
shape
();
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kGraphMode
)
{
const
auto
&
min_shape
=
arg_tensor
->
shape
()
->
min_shape
();
const
auto
&
max_shape
=
arg_tensor
->
shape
()
->
max_shape
();
if
(
!
min_shape
.
empty
()
&&
!
max_shape
.
empty
())
{
dic
[
"min_shape"
]
=
min_shape
;
dic
[
"max_shape"
]
=
max_shape
;
dic
[
ATTR_MIN_SHAPE
]
=
min_shape
;
dic
[
ATTR_MAX_SHAPE
]
=
max_shape
;
}
}
dic
[
"dtype"
]
=
arg_tensor
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_tensor
->
BuildValue
());
dic
[
ATTR_DTYPE
]
=
arg_tensor
->
BuildType
();
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg_tensor
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractRowTensor
>
())
{
auto
arg
=
dyn_cast
<
AbstractRowTensor
>
(
abs_base
);
dic
[
"shape"
]
=
arg
->
shape
()
->
shape
();
dic
[
"dtype"
]
=
arg
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
arg
->
shape
()
->
shape
();
dic
[
ATTR_DTYPE
]
=
arg
->
BuildType
();
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractSparseTensor
>
())
{
auto
arg
=
dyn_cast
<
AbstractSparseTensor
>
(
abs_base
);
dic
[
"shape"
]
=
arg
->
shape
()
->
shape
();
dic
[
"dtype"
]
=
arg
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
arg
->
shape
()
->
shape
();
dic
[
ATTR_DTYPE
]
=
arg
->
BuildType
();
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractScalar
>
()
||
abs_base
->
isa
<
AbstractType
>
()
||
abs_base
->
isa
<
AbstractRefKey
>
())
{
ShapeVector
shape
;
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
abs_base
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
abs_base
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
shape
;
dic
[
ATTR_DTYPE
]
=
abs_base
->
BuildType
();
dic
[
ATTR_VALUE
]
=
BuildValue
(
abs_base
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractSlice
>
())
{
auto
arg_slice
=
dyn_cast
<
AbstractSlice
>
(
abs_base
);
ShapeVector
shape
;
dic
[
"shape"
]
=
shape
;
dic
[
"dtype"
]
=
arg_slice
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_slice
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
shape
;
dic
[
ATTR_DTYPE
]
=
arg_slice
->
BuildType
();
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg_slice
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractRef
>
())
{
auto
value
=
abs_base
->
cast
<
AbstractRefPtr
>
()
->
ref
();
dic
=
ConvertAbstractToPython
(
value
);
}
else
if
(
abs_base
->
isa
<
AbstractEllipsis
>
())
{
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
py
::
ellipsis
();
dic
[
"value"
]
=
py
::
ellipsis
();
dic
[
ATTR_SHAPE
]
=
py
::
none
();
dic
[
ATTR_DTYPE
]
=
py
::
ellipsis
();
dic
[
ATTR_VALUE
]
=
py
::
ellipsis
();
}
else
if
(
abs_base
->
isa
<
AbstractTuple
>
())
{
auto
arg_tuple
=
dyn_cast
<
AbstractTuple
>
(
abs_base
);
size_t
len
=
arg_tuple
->
size
();
...
...
@@ -336,12 +320,12 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
py
::
dict
out
=
ConvertAbstractToPython
(
arg_tuple
->
elements
()[
i
]);
shape_tuple
[
i
]
=
out
[
"shape"
];
dtype_tuple
[
i
]
=
out
[
"dtype"
];
shape_tuple
[
i
]
=
out
[
ATTR_SHAPE
];
dtype_tuple
[
i
]
=
out
[
ATTR_DTYPE
];
}
dic
[
"shape"
]
=
shape_tuple
;
dic
[
"dtype"
]
=
dtype_tuple
;
dic
[
"value"
]
=
BuildValue
(
arg_tuple
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
shape_tuple
;
dic
[
ATTR_DTYPE
]
=
dtype_tuple
;
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg_tuple
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractList
>
())
{
auto
arg_list
=
dyn_cast
<
AbstractList
>
(
abs_base
);
size_t
len
=
arg_list
->
size
();
...
...
@@ -350,25 +334,25 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
for
(
size_t
i
=
0
;
i
<
len
;
i
++
)
{
py
::
dict
out
=
ConvertAbstractToPython
(
arg_list
->
elements
()[
i
]);
shape_list
[
i
]
=
out
[
"shape"
];
dtype_list
[
i
]
=
out
[
"dtype"
];
shape_list
[
i
]
=
out
[
ATTR_SHAPE
];
dtype_list
[
i
]
=
out
[
ATTR_DTYPE
];
}
dic
[
"shape"
]
=
shape_list
;
dic
[
"dtype"
]
=
dtype_list
;
dic
[
"value"
]
=
BuildValue
(
arg_list
->
BuildValue
());
dic
[
ATTR_SHAPE
]
=
shape_list
;
dic
[
ATTR_DTYPE
]
=
dtype_list
;
dic
[
ATTR_VALUE
]
=
BuildValue
(
arg_list
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractNone
>
())
{
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
py
::
none
();
dic
[
"value"
]
=
py
::
none
();
dic
[
ATTR_SHAPE
]
=
py
::
none
();
dic
[
ATTR_DTYPE
]
=
py
::
none
();
dic
[
ATTR_VALUE
]
=
py
::
none
();
}
else
if
(
abs_base
->
isa
<
AbstractFunction
>
())
{
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
abs_base
->
BuildType
();
dic
[
"value"
]
=
py
::
none
();
dic
[
ATTR_SHAPE
]
=
py
::
none
();
dic
[
ATTR_DTYPE
]
=
abs_base
->
BuildType
();
dic
[
ATTR_VALUE
]
=
py
::
none
();
}
else
if
(
abs_base
->
isa
<
AbstractUndetermined
>
())
{
auto
arg
=
dyn_cast
<
AbstractUndetermined
>
(
abs_base
);
dic
[
"shape"
]
=
py
::
none
();
dic
[
"dtype"
]
=
arg
->
BuildType
();
dic
[
"value"
]
=
py
::
none
();
dic
[
ATTR_SHAPE
]
=
py
::
none
();
dic
[
ATTR_DTYPE
]
=
arg
->
BuildType
();
dic
[
ATTR_VALUE
]
=
py
::
none
();
}
else
{
auto
value
=
abs_base
->
BuildValue
();
if
((
*
value
==
*
kAnyValue
))
{
...
...
@@ -409,18 +393,20 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
AbstractBasePtr
PyInferRes2Abstract
(
const
PrimitivePyPtr
&
prim_py
,
const
py
::
dict
&
output
)
{
// Convert to AbstractValue based on type and shape
auto
out_dtype
=
output
[
"dtype"
];
if
(
output
[
"value"
].
is_none
())
{
auto
out_shape
=
output
[
"shape"
];
py
::
object
min_shape
=
output
.
contains
(
"min_shape"
)
?
(
py
::
object
)
output
[
"min_shape"
]
:
(
py
::
object
)
py
::
none
();
py
::
object
max_shape
=
output
.
contains
(
"max_shape"
)
?
(
py
::
object
)
output
[
"max_shape"
]
:
(
py
::
object
)
py
::
none
();
auto
out_dtype
=
output
[
ATTR_DTYPE
];
if
(
output
[
ATTR_VALUE
].
is_none
())
{
auto
out_shape
=
output
[
ATTR_SHAPE
];
py
::
object
min_shape
=
output
.
contains
(
py
::
str
(
ATTR_MIN_SHAPE
))
?
(
py
::
object
)
output
[
ATTR_MIN_SHAPE
]
:
(
py
::
object
)
py
::
none
();
py
::
object
max_shape
=
output
.
contains
(
py
::
str
(
ATTR_MAX_SHAPE
))
?
(
py
::
object
)
output
[
ATTR_MAX_SHAPE
]
:
(
py
::
object
)
py
::
none
();
return
PyListDtype2AbstractTensor
(
out_shape
,
out_dtype
,
min_shape
,
max_shape
);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr
converted_ret
=
nullptr
;
TypePtr
dtype
=
py
::
isinstance
<
Type
>
(
out_dtype
)
?
out_dtype
.
cast
<
TypePtr
>
()
:
nullptr
;
bool
converted
=
parse
::
ConvertData
(
output
[
"value"
],
&
converted_ret
,
false
,
dtype
);
bool
converted
=
parse
::
ConvertData
(
output
[
ATTR_VALUE
],
&
converted_ret
,
false
,
dtype
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Convert data failed"
;
}
...
...
@@ -447,6 +433,73 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
}
// end anonymous namespace
EvalResultPtr
StandardPrimEvaluator
::
EvalPyCheckPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
{
auto
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim_
);
if
(
prim_py
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"The primitive with type 'kPrimTypePyInferCheck' should be a python primitive."
;
}
// Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
MS_LOG
(
DEBUG
)
<<
"Begin input args checking for: "
<<
prim_py
->
ToString
();
auto
py_args
=
PreparePyInputs
(
prim_py
,
args
);
prim_py
->
RunCheck
(
py_args
);
prim_
->
BeginRecordAddAttr
();
AbstractBasePtr
abs_base
=
eval_impl_
(
engine
,
prim_
,
args
);
prim_
->
EndRecordAddAttr
();
auto
added_attrs
=
prim_
->
evaluate_added_attrs
();
if
(
!
py
::
hasattr
(
prim_py
->
GetPyObj
(),
PY_PRIM_METHOD_INFER_VALUE
))
{
return
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
}
// Call method 'infer_value' for primitive with this method for constant propagation
py
::
tuple
py_vals
(
py_args
.
size
());
for
(
size_t
i
=
0
;
i
<
py_args
.
size
();
++
i
)
{
py_vals
[
i
]
=
py_args
[
i
][
ATTR_VALUE
];
}
py
::
object
py_ret
=
prim_py
->
RunInferValue
(
py_vals
);
if
(
py
::
isinstance
<
py
::
none
>
(
py_ret
))
{
return
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr
converted_ret
=
nullptr
;
TypePtr
dtype
=
abs_base
->
BuildType
();
bool
converted
=
parse
::
ConvertData
(
py_ret
,
&
converted_ret
,
false
,
dtype
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Convert data failed"
;
}
auto
res_spec
=
FromValue
(
converted_ret
);
MS_EXCEPTION_IF_NULL
(
res_spec
);
if
(
res_spec
->
isa
<
AbstractTensor
>
())
{
// Replace to tensor constant node in specialize
auto
res_tensor
=
res_spec
->
cast
<
AbstractTensorPtr
>
();
res_tensor
->
set_value
(
converted_ret
);
}
return
std
::
make_shared
<
EvalResult
>
(
res_spec
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
}
EvalResultPtr
StandardPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
)
{
if
(
prims_to_skip_undetermined_infer
.
find
(
prim_
->
name
())
==
prims_to_skip_undetermined_infer
.
end
())
{
auto
ret_abstract
=
AbstractEval
(
args
);
if
(
ret_abstract
!=
nullptr
)
{
MS_LOG
(
DEBUG
)
<<
"StandardPrimEvaluator eval Undetermined"
;
return
ret_abstract
;
}
}
if
(
prim_
->
prim_type
()
==
PrimType
::
kPrimTypePyInferCheck
)
{
return
EvalPyCheckPrim
(
engine
,
args
);
}
prim_
->
BeginRecordAddAttr
();
AbstractBasePtr
abs_base
=
eval_impl_
(
engine
,
prim_
,
args
);
prim_
->
EndRecordAddAttr
();
auto
added_attrs
=
prim_
->
evaluate_added_attrs
();
return
std
::
make_shared
<
EvalResult
>
(
abs_base
,
std
::
make_shared
<
AttrValueMap
>
(
added_attrs
));
}
EvalResultPtr
PythonPrimEvaluator
::
EvalPrim
(
const
AnalysisEnginePtr
&
,
const
AbstractBasePtrList
&
args
)
{
auto
ret_abstract
=
AbstractEval
(
args
);
if
(
ret_abstract
!=
nullptr
)
{
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
浏览文件 @
be3d79cb
...
...
@@ -42,6 +42,8 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
std
::
string
ToString
()
const
override
{
return
identifier_
+
prim_
->
name
();
}
private:
EvalResultPtr
EvalPyCheckPrim
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args
);
PrimitivePtr
prim_
;
const
StandardPrimitiveEvalImpl
eval_impl_
;
};
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc
浏览文件 @
be3d79cb
...
...
@@ -308,20 +308,18 @@ void AnalysisEngine::Clear() {
namespace
{
EvaluatorPtr
GetPrimEvaluator
(
const
PrimitivePtr
&
prim
,
const
AnalysisEnginePtr
&
engine
)
{
// Custom Primitive with python infer_shape, infer_type
EvaluatorPtr
evaluator
=
nullptr
;
MS_EXCEPTION_IF_NULL
(
prim
);
if
(
prim
->
isa
<
prim
::
DoSignaturePrimitive
>
())
{
evaluator
=
std
::
make_shared
<
DoSignatureEvaluator
>
(
prim
);
return
evaluator
;
return
std
::
make_shared
<
DoSignatureEvaluator
>
(
prim
);
}
if
(
prim
->
isa
<
prim
::
UnpackGraphPrimitive
>
())
{
evaluator
=
std
::
make_shared
<
UnpackGraphEvaluator
>
(
prim
);
return
evaluator
;
return
std
::
make_shared
<
UnpackGraphEvaluator
>
(
prim
);
}
if
(
prim
->
Hash
()
==
prim
::
kPrimMixedPrecisionCast
->
Hash
()
&&
prim
->
name
()
==
prim
::
kPrimMixedPrecisionCast
->
name
())
{
evaluator
=
std
::
make_shared
<
MixedPrecisionCastEvaluator
>
(
prim
);
return
evaluator
;
return
std
::
make_shared
<
MixedPrecisionCastEvaluator
>
(
prim
);
}
EvaluatorPtr
evaluator
=
nullptr
;
if
(
prim
->
HasPyEvaluator
())
{
auto
prim_py
=
dyn_cast
<
PrimitivePy
>
(
prim
);
if
(
prim_py
!=
nullptr
)
{
...
...
mindspore/ccsrc/pipeline/jit/validator.cc
浏览文件 @
be3d79cb
...
...
@@ -55,6 +55,10 @@ void ValidateOperation(const AnfNodePtr &node) {
MS_LOG
(
DEBUG
)
<<
"Primitive "
<<
prim
->
name
()
<<
" has python evaluator."
;
return
;
}
if
(
prim
->
prim_type
()
==
PrimType
::
kPrimTypePyInferCheck
)
{
MS_LOG
(
DEBUG
)
<<
"Primitive "
<<
prim
->
name
()
<<
" has python inference checking method."
;
return
;
}
if
(
prim
->
name
()
==
"fake_bprop"
)
{
MS_LOG
(
EXCEPTION
)
<<
"Illegal primitive: "
<<
GetValue
<
std
::
string
>
(
prim
->
GetAttr
(
"info"
));
}
...
...
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
浏览文件 @
be3d79cb
...
...
@@ -254,16 +254,33 @@ py::dict PrimitivePy::RunInfer(const py::tuple &args) {
if
(
!
HasPyObj
())
{
MS_LOG
(
EXCEPTION
)
<<
"["
<<
this
->
ToString
()
<<
"]: pyobj is empty"
;
}
auto
infer_fuc
=
python_obj_
.
attr
(
"__infer__"
);
auto
infer_fuc
=
python_obj_
.
attr
(
PY_PRIM_METHOD_INFER
);
return
infer_fuc
(
*
args
);
}
void
PrimitivePy
::
RunCheck
(
const
py
::
tuple
&
args
)
{
if
(
!
HasPyObj
())
{
MS_LOG
(
EXCEPTION
)
<<
"["
<<
this
->
ToString
()
<<
"]: pyobj is empty"
;
}
auto
check_func
=
python_obj_
.
attr
(
PY_PRIM_METHOD_CHECK
);
(
void
)
check_func
(
*
args
);
}
py
::
object
PrimitivePy
::
RunInferValue
(
const
py
::
tuple
&
args
)
{
if
(
!
HasPyObj
())
{
MS_LOG
(
EXCEPTION
)
<<
"["
<<
this
->
ToString
()
<<
"]: pyobj is empty"
;
}
auto
infer_value
=
python_obj_
.
attr
(
PY_PRIM_METHOD_INFER_VALUE
);
return
infer_value
(
*
args
);
}
REGISTER_PYBIND_DEFINE
(
Primitive_
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
PrimType
>
(
*
m
,
"prim_type"
,
py
::
arithmetic
())
.
value
(
"unknown"
,
PrimType
::
kPrimTypeUnknown
)
.
value
(
"builtin"
,
PrimType
::
kPrimTypeBuiltIn
)
.
value
(
"py_infer_shape"
,
PrimType
::
kPrimTypePyInferShape
)
.
value
(
"user_custom"
,
PrimType
::
kPrimTypeUserCustom
);
.
value
(
"user_custom"
,
PrimType
::
kPrimTypeUserCustom
)
.
value
(
"py_infer_check"
,
PrimType
::
kPrimTypePyInferCheck
);
(
void
)
py
::
class_
<
PrimitivePy
,
std
::
shared_ptr
<
PrimitivePy
>>
(
*
m
,
"Primitive_"
)
.
def_readonly
(
PYTHON_PRIMITIVE_FLAG
,
&
PrimitivePy
::
parse_info_
)
.
def
(
py
::
init
<
py
::
str
&
,
py
::
object
>
())
...
...
mindspore/ccsrc/pybind_api/ir/primitive_py.h
浏览文件 @
be3d79cb
...
...
@@ -62,6 +62,8 @@ class PrimitivePy : public Primitive {
const
bool
parse_info_
=
true
;
const
py
::
object
&
GetPyObj
()
const
{
return
python_obj_
;
}
py
::
dict
RunInfer
(
const
py
::
tuple
&
args
);
void
RunCheck
(
const
py
::
tuple
&
args
);
py
::
object
RunInferValue
(
const
py
::
tuple
&
args
);
bool
ObjHasAttr
(
const
char
*
attr_name
)
{
return
py
::
hasattr
(
python_obj_
,
attr_name
);
}
bool
HasPyObj
()
{
return
python_obj_
.
operator
bool
();
}
PrimitivePtr
Clone
()
override
;
...
...
mindspore/core/abstract/infer_functions.h
浏览文件 @
be3d79cb
...
...
@@ -81,6 +81,9 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
AbstractBasePtr
InferImplMinOrMaxGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSqrt
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplScalarToArray
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplArrayToScalar
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
...
...
@@ -176,6 +179,14 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUnique
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGatherV2
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDynamicShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseApplyFtrl
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplSparseApplyProximalAdagrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
template
<
typename
T
>
AbstractBasePtr
InferTupleOrListOrDictLen
(
const
std
::
string
&
op_name
,
const
AbstractBasePtrList
&
args_spec_list
)
{
...
...
mindspore/core/abstract/prim_arrays.cc
浏览文件 @
be3d79cb
...
...
@@ -14,6 +14,8 @@
* limitations under the License.
*/
#include <algorithm>
#include <iterator>
#include "abstract/infer_functions.h"
#include "abstract/utils.h"
#include "abstract/param_validator.h"
...
...
@@ -226,5 +228,60 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
// outputs: dx
return
std
::
make_shared
<
AbstractTensor
>
(
ids
->
element
(),
ids_idx
->
shape
());
}
AbstractBasePtr
InferImplGatherV2
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
const
std
::
string
&
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
3
);
AbstractTensorPtr
params
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
AbstractTensorPtr
indices
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
1
);
AbstractScalarPtr
axis
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
2
);
auto
params_shp
=
params
->
shape
()
->
shape
();
auto
indices_shp
=
indices
->
shape
()
->
shape
();
auto
axis_val
=
GetValue
<
int
>
(
axis
->
BuildValue
());
auto
params_rank
=
static_cast
<
int
>
(
params_shp
.
size
());
if
(
axis_val
<
0
)
{
axis_val
+=
params_rank
;
}
auto
calc_shape
=
[
axis_val
,
&
params_shp
](
const
ShapeVector
&
inp_vec
)
->
ShapeVector
{
ShapeVector
out_vec
;
std
::
copy
(
params_shp
.
begin
(),
params_shp
.
begin
()
+
axis_val
,
std
::
back_inserter
(
out_vec
));
copy
(
inp_vec
.
begin
(),
inp_vec
.
end
(),
std
::
back_inserter
(
out_vec
));
copy
(
params_shp
.
begin
()
+
axis_val
+
1
,
params_shp
.
end
(),
std
::
back_inserter
(
out_vec
));
return
out_vec
;
};
ShapeVector
out_shape
=
calc_shape
(
indices_shp
);
if
(
!
indices
->
shape
()
->
min_shape
().
empty
()
&&
!
indices
->
shape
()
->
max_shape
().
empty
())
{
ShapeVector
min_shape
=
calc_shape
(
indices
->
shape
()
->
min_shape
());
ShapeVector
max_shape
=
calc_shape
(
indices
->
shape
()
->
max_shape
());
return
std
::
make_shared
<
AbstractTensor
>
(
params
->
element
(),
std
::
make_shared
<
Shape
>
(
out_shape
,
min_shape
,
max_shape
));
}
return
std
::
make_shared
<
AbstractTensor
>
(
params
->
element
(),
std
::
make_shared
<
Shape
>
(
out_shape
));
}
AbstractBasePtr
InferImplDynamicShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
const
std
::
string
&
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
1
);
AbstractTensorPtr
input
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
auto
shape
=
input
->
shape
()
->
shape
();
bool
has_dyn_shape
=
std
::
any_of
(
shape
.
begin
(),
shape
.
end
(),
[](
int
dim
)
{
return
dim
==
Shape
::
SHP_ANY
;
});
std
::
vector
<
int
>
tensor_shp
({
static_cast
<
int
>
(
shape
.
size
())});
if
(
has_dyn_shape
)
{
auto
elem
=
std
::
make_shared
<
AbstractScalar
>
(
std
::
make_shared
<
AnyValue
>
(),
std
::
make_shared
<
Int
>
(
32
));
return
std
::
make_shared
<
AbstractTensor
>
(
elem
,
std
::
make_shared
<
Shape
>
(
tensor_shp
));
}
auto
shp_buf_size
=
sizeof
(
int
)
*
shape
.
size
();
auto
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
kNumberTypeInt32
,
tensor_shp
,
shape
.
data
(),
shp_buf_size
);
return
tensor
->
ToAbstract
();
}
}
// namespace abstract
}
// namespace mindspore
mindspore/core/abstract/prim_maths.cc
浏览文件 @
be3d79cb
...
...
@@ -37,5 +37,14 @@ AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const Primitive
return
std
::
make_shared
<
AbstractTuple
>
(
AbstractBasePtrList
({
dx
,
dy
}));
}
AbstractBasePtr
InferImplSqrt
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: three tensors.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
1
);
auto
inp
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
return
inp
->
Clone
()
->
Broaden
();
}
}
// namespace abstract
}
// namespace mindspore
mindspore/core/abstract/prim_nn.cc
浏览文件 @
be3d79cb
...
...
@@ -445,5 +445,25 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
return
std
::
make_shared
<
AbstractTensor
>
(
std
::
make_shared
<
AbstractScalar
>
(
kAnyValue
,
kUInt8
),
std
::
make_shared
<
Shape
>
(
std
::
vector
<
int64_t
>
{
shape_y
}));
}
AbstractBasePtr
InferImplSparseApplyFtrl
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
5
);
AbstractBasePtrList
elements
;
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
elements
.
push_back
(
args_spec_list
[
i
]
->
Clone
()
->
Broaden
());
}
return
std
::
make_shared
<
AbstractTuple
>
(
elements
);
}
AbstractBasePtr
InferImplSparseApplyProximalAdagrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
CheckArgsSize
(
primitive
->
name
(),
args_spec_list
,
7
);
AbstractBasePtrList
elements
;
for
(
size_t
i
=
0
;
i
<
2
;
++
i
)
{
elements
.
push_back
(
args_spec_list
[
i
]
->
Clone
()
->
Broaden
());
}
return
std
::
make_shared
<
AbstractTuple
>
(
elements
);
}
}
// namespace abstract
}
// namespace mindspore
mindspore/core/abstract/primitive_infer_map.cc
浏览文件 @
be3d79cb
...
...
@@ -37,6 +37,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
// Maths
{
prim
::
kPrimMaximumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMinimumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimSqrt
,
{
InferImplSqrt
,
true
}},
// Array
{
prim
::
kPrimScalarToArray
,
{
InferImplScalarToArray
,
true
}},
{
prim
::
kPrimArrayToScalar
,
{
InferImplArrayToScalar
,
true
}},
...
...
@@ -44,6 +45,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimPack
,
{
InferImplPack
,
true
}},
{
prim
::
kPrimUnique
,
{
InferImplUnique
,
true
}},
{
prim
::
kPrimUniqueGrad
,
{
InferImplUniqueGrad
,
true
}},
{
prim
::
kPrimGatherV2
,
{
InferImplGatherV2
,
true
}},
{
prim
::
kPrimSparseGatherV2
,
{
InferImplGatherV2
,
true
}},
{
prim
::
kPrimDynamicShape
,
{
InferImplDynamicShape
,
true
}},
// Structure
{
prim
::
kPrimMakeTuple
,
{
InferImplMakeTuple
,
true
}},
{
prim
::
kPrimMakeList
,
{
InferImplMakeList
,
true
}},
...
...
@@ -77,6 +81,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimLayerNorm
,
{
InferImplLayerNorm
,
true
}},
{
prim
::
kPrimLayerNormGrad
,
{
InferImplLayerNormGrad
,
true
}},
{
prim
::
kPrimDropoutGenMask
,
{
InferImplDropoutGenMask
,
true
}},
{
prim
::
kPrimSparseApplyFtrl
,
{
InferImplSparseApplyFtrl
,
true
}},
{
prim
::
kPrimSparseApplyProximalAdagrad
,
{
InferImplSparseApplyProximalAdagrad
,
true
}},
// Others
{
prim
::
kPrimIdentity
,
{
InferImplIdentity
,
true
}},
// Set impl to null as it will use PartialEvaluator;
...
...
mindspore/core/base/core_ops.h
浏览文件 @
be3d79cb
...
...
@@ -84,6 +84,9 @@ inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
inline
const
PrimitivePtr
kPrimSqueeze
=
std
::
make_shared
<
Primitive
>
(
"Squeeze"
);
inline
const
PrimitivePtr
kPrimTranspose
=
std
::
make_shared
<
Primitive
>
(
"Transpose"
);
inline
const
PrimitivePtr
kPrimGatherV2
=
std
::
make_shared
<
Primitive
>
(
"GatherV2"
);
inline
const
PrimitivePtr
kPrimSparseGatherV2
=
std
::
make_shared
<
Primitive
>
(
"SparseGatherV2"
);
inline
const
PrimitivePtr
kPrimShape
=
std
::
make_shared
<
Primitive
>
(
"Shape"
);
inline
const
PrimitivePtr
kPrimDynamicShape
=
std
::
make_shared
<
Primitive
>
(
"DynamicShape"
);
inline
const
PrimitivePtr
kPrimEmbeddingLookup
=
std
::
make_shared
<
Primitive
>
(
"EmbeddingLookup"
);
inline
const
PrimitivePtr
kPrimEmbeddingLookupCommGrad
=
std
::
make_shared
<
Primitive
>
(
"EmbeddingLookupCommGrad"
);
inline
const
PrimitivePtr
kPrimSize
=
std
::
make_shared
<
Primitive
>
(
"Size"
);
...
...
@@ -154,6 +157,8 @@ inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut
inline
const
PrimitivePtr
kPrimFakeQuantPerLayer
=
std
::
make_shared
<
Primitive
>
(
"FakeQuantPerLayer"
);
inline
const
PrimitivePtr
kPrimFakeQuantPerChannel
=
std
::
make_shared
<
Primitive
>
(
"FakeQuantPerChannel"
);
inline
const
PrimitivePtr
kPrimApplyRMSProp
=
std
::
make_shared
<
Primitive
>
(
"ApplyRMSProp"
);
inline
const
PrimitivePtr
kPrimSparseApplyFtrl
=
std
::
make_shared
<
Primitive
>
(
"SparseApplyFtrl"
);
inline
const
PrimitivePtr
kPrimSparseApplyProximalAdagrad
=
std
::
make_shared
<
Primitive
>
(
"SparseApplyProximalAdagrad"
);
// Comm ops
inline
const
PrimitivePtr
kPrimMirror
=
std
::
make_shared
<
Primitive
>
(
"_MirrorOperator"
);
...
...
mindspore/core/ir/primitive.h
浏览文件 @
be3d79cb
...
...
@@ -35,7 +35,8 @@ enum PrimType {
kPrimTypeBuiltIn
,
// Built-in primitive operator
kPrimTypePyInferShape
,
// Primitive operator defined by custom
kPrimTypePyInferTensor
,
// Primitive operator defined by custom
kPrimTypeUserCustom
kPrimTypeUserCustom
,
kPrimTypePyInferCheck
// Primitive operator with input args checking method
};
class
Primitive
:
public
Named
{
...
...
mindspore/core/utils/flags.cc
浏览文件 @
be3d79cb
...
...
@@ -23,4 +23,19 @@ const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[]
=
"_effect_patial_order"
;
const
char
GRAPH_FLAG_RANDOM_EFFECT
[]
=
"_random_effect"
;
const
char
GRAPH_FLAG_SIDE_EFFECT
[]
=
"_side_effect"
;
// method names of python primitive called from c++ source code
// 1. infer method name of class 'PrimitiveWithInfer'
const
char
PY_PRIM_METHOD_INFER
[]
=
"__infer__"
;
// 2. check method name of class 'PrimitiveWithCheck'
const
char
PY_PRIM_METHOD_CHECK
[]
=
"__check__"
;
// 3. method name of class 'PrimitivePy' for constant propagation
const
char
PY_PRIM_METHOD_INFER_VALUE
[]
=
"infer_value"
;
// type inference related attributes
const
char
ATTR_VALUE
[]
=
"value"
;
const
char
ATTR_DTYPE
[]
=
"dtype"
;
const
char
ATTR_SHAPE
[]
=
"shape"
;
const
char
ATTR_MIN_SHAPE
[]
=
"min_shape"
;
const
char
ATTR_MAX_SHAPE
[]
=
"max_shape"
;
}
// namespace mindspore
mindspore/core/utils/flags.h
浏览文件 @
be3d79cb
...
...
@@ -23,6 +23,16 @@ extern const char GRAPH_FLAG_HAS_EFFECT[];
extern
const
char
GRAPH_FLAG_EFFECT_PATIAL_ORDER
[];
extern
const
char
GRAPH_FLAG_RANDOM_EFFECT
[];
extern
const
char
GRAPH_FLAG_SIDE_EFFECT
[];
extern
const
char
PY_PRIM_METHOD_INFER
[];
extern
const
char
PY_PRIM_METHOD_CHECK
[];
extern
const
char
PY_PRIM_METHOD_INFER_VALUE
[];
extern
const
char
ATTR_VALUE
[];
extern
const
char
ATTR_DTYPE
[];
extern
const
char
ATTR_SHAPE
[];
extern
const
char
ATTR_MIN_SHAPE
[];
extern
const
char
ATTR_MAX_SHAPE
[];
}
// namespace mindspore
#endif // MINDSPORE_CORE_UTILS_FLAGS_H
mindspore/ops/operations/__init__.py
浏览文件 @
be3d79cb
...
...
@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank
,
Reshape
,
ResizeNearestNeighbor
,
ArgMinWithValue
,
SameTypeShape
,
ScatterAdd
,
ScatterSub
,
ScatterMul
,
ScatterDiv
,
ScatterMax
,
ScatterMin
,
ScatterUpdate
,
ScalarToArray
,
ScalarToTensor
,
ScatterNd
,
ScatterNdUpdate
,
Select
,
Shape
,
Size
,
Slice
,
Split
,
TransShape
,
ParallelConcat
,
Padding
,
Shape
,
DynamicShape
,
Size
,
Slice
,
Split
,
TransShape
,
ParallelConcat
,
Padding
,
ScatterNdAdd
,
ScatterNdSub
,
ScatterNonAliasingAdd
,
ReverseV2
,
Rint
,
Squeeze
,
StridedSlice
,
Tile
,
TensorScatterUpdate
,
EditDistance
,
Transpose
,
TruncatedNormal
,
TupleToArray
,
UnsortedSegmentMin
,
UnsortedSegmentProd
,
...
...
@@ -206,6 +206,7 @@ __all__ = [
'HookBackward'
,
'InvertPermutation'
,
'Shape'
,
'DynamicShape'
,
'DropoutDoMask'
,
'DropoutGenMask'
,
'DropoutGrad'
,
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
be3d79cb
...
...
@@ -27,7 +27,7 @@ import numpy as np
from
.._utils
import
get_concat_offset
from
..operations.math_ops
import
_infer_shape_reduce
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
PrimitiveWithCheck
,
prim_attr_register
,
_run_op
from
..._c_expression
import
signature_dtype
as
sig_dtype
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._c_expression
import
signature_rw
as
sig_rw
...
...
@@ -142,6 +142,11 @@ class ExpandDims(PrimitiveWithInfer):
out
=
{
'shape'
:
x_shape
,
'dtype'
:
x
[
'dtype'
],
'value'
:
value
}
if
'min_shape'
in
x
and
'max_shape'
in
x
:
out
[
'min_shape'
]
=
x
[
'min_shape'
]
out
[
'min_shape'
].
insert
(
axis_v
,
1
)
out
[
'max_shape'
]
=
x
[
'max_shape'
]
out
[
'max_shape'
].
insert
(
axis_v
,
1
)
return
out
...
...
@@ -277,6 +282,9 @@ class Cast(PrimitiveWithInfer):
out
=
{
'shape'
:
x
[
'shape'
],
'dtype'
:
mstype
.
tensor_type
(
t
[
'value'
]),
'value'
:
value
}
if
'min_shape'
in
x
and
'max_shape'
in
x
:
out
[
'min_shape'
]
=
x
[
'min_shape'
]
out
[
'max_shape'
]
=
x
[
'max_shape'
]
return
out
...
...
@@ -445,6 +453,27 @@ class Shape(PrimitiveWithInfer):
return
out
class
DynamicShape
(
Primitive
):
"""
Returns the shape of input tensor.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor[int], 1-dim Tensor of type int32
Examples:
>>> input_tensor = Tensor(np.ones(shape=[3, 2, 1]), mindspore.float32)
>>> shape = P.DynamicShape()
>>> output = shape(input_tensor)
"""
@
prim_attr_register
def
__init__
(
self
):
"""init Shape"""
class
Squeeze
(
PrimitiveWithInfer
):
"""
Returns a tensor with the same type but dimensions of 1 being removed based on axis.
...
...
@@ -578,7 +607,7 @@ class Unique(Primitive):
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'output'
])
class
GatherV2
(
PrimitiveWith
Infer
):
class
GatherV2
(
PrimitiveWith
Check
):
"""
Returns a slice of input tensor based on the specified indices and axis.
...
...
@@ -605,7 +634,7 @@ class GatherV2(PrimitiveWithInfer):
"""init index_select"""
self
.
init_prim_io_names
(
inputs
=
[
'params'
,
'indices'
,
'axis'
],
outputs
=
[
'output'
])
def
__
infer
__
(
self
,
params
,
indices
,
axis
):
def
__
check
__
(
self
,
params
,
indices
,
axis
):
validator
.
check_subclass
(
"params"
,
params
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
mstype
.
int_type
,
self
.
name
)
validator
.
check_subclass
(
"axis"
,
axis
[
'dtype'
],
mstype
.
int_
,
self
.
name
)
...
...
@@ -613,13 +642,6 @@ class GatherV2(PrimitiveWithInfer):
params_shp
=
params
[
'shape'
]
rank
=
len
(
params_shp
)
validator
.
check_int_range
(
"axis"
,
axis_v
,
-
rank
,
rank
,
Rel
.
INC_LEFT
,
self
.
name
)
if
axis_v
<
0
:
axis_v
+=
rank
out_shape
=
params_shp
[:
axis_v
]
+
indices
[
'shape'
]
+
params_shp
[
axis_v
+
1
:]
out
=
{
'shape'
:
out_shape
,
'dtype'
:
params
[
'dtype'
],
'value'
:
None
}
return
out
class
SparseGatherV2
(
GatherV2
):
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
be3d79cb
...
...
@@ -26,7 +26,7 @@ from ..._checkparam import Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
.._utils
import
get_broadcast_shape
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
from
..primitive
import
PrimitiveWithInfer
,
PrimitiveWithCheck
,
prim_attr_register
,
_run_op
def
_infer_shape_reduce
(
x
,
axis
,
keep_dims
,
prim_name
):
...
...
@@ -1257,7 +1257,7 @@ class Rsqrt(PrimitiveWithInfer):
return
None
class
Sqrt
(
PrimitiveWith
Infer
):
class
Sqrt
(
PrimitiveWith
Check
):
"""
Returns square root of a tensor element-wise.
...
...
@@ -1279,12 +1279,8 @@ class Sqrt(PrimitiveWithInfer):
"""init Sqrt"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
x_shape
):
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
check_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
def
infer_value
(
self
,
x
):
if
x
is
not
None
:
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
be3d79cb
...
...
@@ -28,7 +28,7 @@ from ..._c_expression import signature_dtype as sig_dtype
from
..._checkparam
import
Validator
as
validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
PrimitiveWithCheck
,
prim_attr_register
from
..operations.math_ops
import
_infer_shape_reduce
...
...
@@ -4354,7 +4354,7 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
return
var_dtype
,
accum_dtype
class
SparseApplyProximalAdagrad
(
PrimitiveWith
Infer
):
class
SparseApplyProximalAdagrad
(
PrimitiveWith
Check
):
r
"""
Update relevant entries according to the proximal adagrad algorithm. Compared with ApplyProximalAdagrad,
an additional index tensor is input.
...
...
@@ -4433,11 +4433,10 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
outputs
=
[
'var'
,
'accum'
])
self
.
use_locking
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
def
infer
_shape
(
self
,
var_shape
,
accum_shape
,
lr_shape
,
l1_shape
,
l2_shape
,
grad_shape
,
indices_shape
):
def
check
_shape
(
self
,
var_shape
,
accum_shape
,
lr_shape
,
l1_shape
,
l2_shape
,
grad_shape
,
indices_shape
):
validator
.
check_integer
(
"indices rank"
,
len
(
indices_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
var_shape
,
accum_shape
def
infer
_dtype
(
self
,
var_dtype
,
accum_dtype
,
lr_dtype
,
l1_dtype
,
l2_dtype
,
grad_dtype
,
indices_dtype
):
def
check
_dtype
(
self
,
var_dtype
,
accum_dtype
,
lr_dtype
,
l1_dtype
,
l2_dtype
,
grad_dtype
,
indices_dtype
):
args
=
{
'var'
:
var_dtype
,
'accum'
:
accum_dtype
,
'grad'
:
grad_dtype
}
validator
.
check_tensor_type_same
(
args
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_scalar_or_tensor_type_same
({
"lr"
:
lr_dtype
},
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
...
...
@@ -4446,7 +4445,6 @@ class SparseApplyProximalAdagrad(PrimitiveWithInfer):
valid_types
=
[
mstype
.
int16
,
mstype
.
int32
,
mstype
.
int64
,
mstype
.
uint16
,
mstype
.
uint32
,
mstype
.
uint64
]
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
valid_types
,
self
.
name
)
return
var_dtype
,
accum_dtype
class
ApplyAddSign
(
PrimitiveWithInfer
):
...
...
@@ -4978,7 +4976,7 @@ class ApplyFtrl(PrimitiveWithInfer):
return
var_type
class
SparseApplyFtrl
(
PrimitiveWith
Infer
):
class
SparseApplyFtrl
(
PrimitiveWith
Check
):
"""
Update relevant entries according to the FTRL-proximal scheme.
...
...
@@ -5053,21 +5051,19 @@ class SparseApplyFtrl(PrimitiveWithInfer):
self
.
lr_power
=
validator
.
check_number
(
"lr_power"
,
lr_power
,
0
,
Rel
.
LE
,
self
.
name
)
self
.
use_locking
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
def
infer
_shape
(
self
,
var_shape
,
accum_shape
,
linear_shape
,
grad_shape
,
indices_shape
):
def
check
_shape
(
self
,
var_shape
,
accum_shape
,
linear_shape
,
grad_shape
,
indices_shape
):
validator
.
check
(
'var shape'
,
var_shape
,
'accum shape'
,
accum_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'var shape'
,
var_shape
,
'linear shape'
,
linear_shape
,
Rel
.
EQ
,
self
.
name
)
if
len
(
var_shape
)
>
1
:
validator
.
check
(
'var_shape[1:]'
,
var_shape
[
1
:],
'grad_shape[1:]'
,
grad_shape
[
1
:],
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"indices rank"
,
len
(
indices_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'grad_shape[0]'
,
grad_shape
[
0
],
'indices_shape[0]'
,
indices_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
return
var_shape
,
accum_shape
,
linear_shape
def
infer
_dtype
(
self
,
var_dtype
,
accum_dtype
,
linear_dtype
,
grad_dtype
,
indices_dtype
):
def
check
_dtype
(
self
,
var_dtype
,
accum_dtype
,
linear_dtype
,
grad_dtype
,
indices_dtype
):
args
=
{
"var_dtype"
:
var_dtype
,
"accum_dtype"
:
accum_dtype
,
"linear_dtype"
:
linear_dtype
,
"grad_dtype"
:
grad_dtype
}
validator
.
check_tensor_type_same
(
args
,
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"indices_dtype"
:
indices_dtype
},
[
mstype
.
int32
],
self
.
name
)
return
var_dtype
,
accum_dtype
,
linear_dtype
class
SparseApplyFtrlV2
(
PrimitiveWithInfer
):
...
...
mindspore/ops/primitive.py
浏览文件 @
be3d79cb
...
...
@@ -200,6 +200,84 @@ class Primitive(Primitive_):
return
self
.
_update_parameter
class
PrimitiveWithCheck
(
Primitive
):
"""
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
but used the infer method registed in c++ source codes.
There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
If __check__() is not defined, infer_shape() and infer_dtype() can be defined to describe the check logic of
the shape and type.
Args:
name (str): Name of the current Primitive.
Examples:
>>> # init a Primitive class with check
>>> class Flatten(PrimitiveWithCheck):
>>> @prim_attr_register
>>> def __init__(self):
>>> pass
>>> def check_shape(self, input_x):
>>> validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name)
>>>
>>> def check_dtype(self, input_x):
>>> validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
>>>
>>> # init a Primitive obj
>>> add = Flatten()
"""
def
__init__
(
self
,
name
):
Primitive
.
__init__
(
self
,
name
)
self
.
set_prim_type
(
prim_type
.
py_infer_check
)
def
_clone
(
self
):
"""
Deeply clones the primitive object.
Calls the __init__() method with the same arguments. This method is called in parser if the
flag self.__setattr_flag__ is True.
"""
cloned_prim
=
Primitive
.
_clone
(
self
)
return
cloned_prim
def
check_shape
(
self
,
*
args
):
"""
Check shapes of input args.
Note:
The shape of scalar is an empty tuple.
Args:
args (tuple(int)): shapes of input tensors.
Return:
None.
"""
return
None
def
check_dtype
(
self
,
*
args
):
"""
Check data types of input args.
Args:
args (:class:`mindspore.dtype`): data type of inputs.
Return:
None.
"""
return
None
def
__check__
(
self
,
*
args
):
"""Check shape, type, and value at the same time by using dictionary as arguments."""
tracks
=
[
'dtype'
,
'shape'
]
for
track
in
tracks
:
fn
=
getattr
(
self
,
'check_'
+
track
)
fn
(
*
(
x
[
track
]
for
x
in
args
))
class
PrimitiveWithInfer
(
Primitive
):
"""
PrimitiveWithInfer is the base class of primitives in python defines functions for tracking inference in python.
...
...
@@ -306,6 +384,18 @@ class PrimitiveWithInfer(Primitive):
if
not
is_graph_mode
:
return
out
# output does not contain dynamic shape, no need to calculate min/max shape
def
has_dynamic_shape
(
shp
):
if
isinstance
(
shp
,
int
):
return
shp
<
0
if
isinstance
(
shp
,
(
list
,
tuple
)):
return
any
(
has_dynamic_shape
(
e
)
for
e
in
shp
)
return
False
if
not
has_dynamic_shape
(
out
[
'shape'
]):
return
out
# calculate min/max shape for output
def
get_specified_shape
(
elems
,
attr
):
has_specified_shape
=
False
ret_vals
=
[]
...
...
@@ -345,6 +435,8 @@ def prim_attr_register(fn):
def
deco
(
self
,
*
args
,
**
kwargs
):
if
isinstance
(
self
,
PrimitiveWithInfer
):
PrimitiveWithInfer
.
__init__
(
self
,
self
.
__class__
.
__name__
)
elif
isinstance
(
self
,
PrimitiveWithCheck
):
PrimitiveWithCheck
.
__init__
(
self
,
self
.
__class__
.
__name__
)
else
:
Primitive
.
__init__
(
self
,
self
.
__class__
.
__name__
)
bound_args
=
inspect
.
signature
(
fn
).
bind
(
self
,
*
args
,
**
kwargs
)
...
...
tests/ut/python/ir/test_row_tensor.py
浏览文件 @
be3d79cb
...
...
@@ -27,7 +27,7 @@ from mindspore.ops import composite as C
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
operations
as
P
from
mindspore.ops.composite.multitype_ops.zeros_like_impl
import
zeros_like
from
mindspore.ops.primitive
import
constexpr
from
mindspore.ops.primitive
import
constexpr
,
PrimitiveWithInfer
,
prim_attr_register
from
mindspore.ops._grad.grad_base
import
bprop_getters
from
mindspore
import
Tensor
,
RowTensor
,
context
from
mindspore.common.parameter
import
Parameter
,
ParameterTuple
...
...
@@ -105,10 +105,31 @@ def _generate_inverse_index(x_shape, axis):
perm
=
index
[
1
:
1
+
axis
]
+
(
0
,)
+
index
[
1
+
axis
:]
return
perm
class
MySparseGatherV2
(
P
.
GatherV2
):
# pylint: disable=W0231
class
MySparseGatherV2
(
PrimitiveWithInfer
):
"""
For test
"""
@
prim_attr_register
def
__init__
(
self
):
"""init index_select"""
self
.
init_prim_io_names
(
inputs
=
[
'params'
,
'indices'
,
'axis'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
params
,
indices
,
axis
):
validator
.
check_subclass
(
"params"
,
params
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
mstype
.
int_type
,
self
.
name
)
validator
.
check_subclass
(
"axis"
,
axis
[
'dtype'
],
mstype
.
int_
,
self
.
name
)
axis_v
=
axis
[
'value'
]
params_shp
=
params
[
'shape'
]
rank
=
len
(
params_shp
)
validator
.
check_int_range
(
"axis"
,
axis_v
,
-
rank
,
rank
,
Rel
.
INC_LEFT
,
self
.
name
)
if
axis_v
<
0
:
axis_v
+=
rank
out_shape
=
params_shp
[:
axis_v
]
+
indices
[
'shape'
]
+
params_shp
[
axis_v
+
1
:]
out
=
{
'shape'
:
out_shape
,
'dtype'
:
params
[
'dtype'
],
'value'
:
None
}
return
out
@
bprop_getters
.
register
(
MySparseGatherV2
)
def
get_bprop_sparse_gather_v2
(
self
):
...
...
tests/ut/python/ops/test_dynamic_shape.py
0 → 100755
浏览文件 @
be3d79cb
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test dynamic shape """
from
mindspore
import
Tensor
,
context
,
nn
,
Parameter
from
mindspore.ops
import
operations
as
P
from
mindspore
import
dtype
as
mstype
import
numpy
as
np
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
False
)
def
test_sparse_apply_proximal_ada_grad
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
sparse_apply_proximal_adagrad
=
P
.
SparseApplyProximalAdagrad
()
self
.
var
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
7800
,
80
).
astype
(
np
.
float32
)),
name
=
"var"
)
self
.
accum
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
7800
,
80
).
astype
(
np
.
float32
)),
name
=
"accum"
)
self
.
lr
=
0.01
self
.
l1
=
0.0
self
.
l2
=
0.0
def
construct
(
self
,
grad
,
indices
):
out
=
self
.
sparse_apply_proximal_adagrad
(
self
.
var
,
self
.
accum
,
self
.
lr
,
self
.
l1
,
self
.
l2
,
grad
,
indices
)
return
out
[
0
]
class
NetWrapper
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetWrapper
,
self
).
__init__
()
self
.
unq
=
P
.
Unique
()
self
.
add
=
P
.
TensorAdd
()
self
.
expand_dims
=
P
.
ExpandDims
()
self
.
cast
=
P
.
Cast
()
self
.
net
=
Net
()
def
construct
(
self
,
grad
,
inp
):
ids
,
_
=
self
.
unq
(
inp
)
new_grad
=
self
.
expand_dims
(
ids
,
1
)
new_grad
=
self
.
cast
(
new_grad
,
mstype
.
float32
)
+
grad
return
self
.
net
(
new_grad
,
ids
)
net
=
NetWrapper
()
grad
=
Tensor
(
np
.
random
.
rand
(
1
,
80
).
astype
(
np
.
float32
))
indices
=
Tensor
(
np
.
ones
([
7800
]),
mstype
.
int32
)
net
(
grad
,
indices
)
def
test_sparse_apply_ftrl
():
class
SparseApplyFtrlNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseApplyFtrlNet
,
self
).
__init__
()
self
.
sparse_apply_ftrl
=
P
.
SparseApplyFtrl
(
lr
=
0.01
,
l1
=
0.0
,
l2
=
0.0
,
lr_power
=-
0.5
)
self
.
var
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
7800
,
80
).
astype
(
np
.
float32
)),
name
=
"var"
)
self
.
accum
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
7800
,
80
).
astype
(
np
.
float32
)),
name
=
"accum"
)
self
.
linear
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
7800
,
80
).
astype
(
np
.
float32
)),
name
=
"linear"
)
def
construct
(
self
,
grad
,
indices
):
out
=
self
.
sparse_apply_ftrl
(
self
.
var
,
self
.
accum
,
self
.
linear
,
grad
,
indices
)
return
out
[
0
]
class
NetWrapper
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetWrapper
,
self
).
__init__
()
self
.
unq
=
P
.
Unique
()
self
.
add
=
P
.
TensorAdd
()
self
.
expand_dims
=
P
.
ExpandDims
()
self
.
cast
=
P
.
Cast
()
self
.
net
=
SparseApplyFtrlNet
()
def
construct
(
self
,
grad
,
inp
):
ids
,
_
=
self
.
unq
(
inp
)
new_grad
=
self
.
expand_dims
(
ids
,
1
)
new_grad
=
self
.
cast
(
new_grad
,
mstype
.
float32
)
+
grad
return
self
.
net
(
new_grad
,
ids
)
net
=
NetWrapper
()
grad
=
Tensor
(
np
.
random
.
rand
(
1
,
80
).
astype
(
np
.
float32
))
indices
=
Tensor
(
np
.
ones
([
7800
]),
mstype
.
int32
)
net
(
grad
,
indices
)
def
test_gatherv2
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
unq
=
P
.
Unique
()
self
.
gather
=
P
.
GatherV2
()
def
construct
(
self
,
x
,
y
):
u
,
_
=
self
.
unq
(
y
)
z
=
self
.
gather
(
x
,
u
,
0
)
return
z
x
=
Tensor
(
np
.
ones
([
20
,
12
],
dtype
=
np
.
float32
))
y
=
Tensor
(
np
.
ones
([
8
],
dtype
=
np
.
int32
))
net
=
Net
()
net
(
x
,
y
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录