Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
b075674c
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b075674c
编写于
7月 31, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support tensor attr shape and dtype in graph mode
上级
fa96dfd1
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
211 addition
and
128 deletion
+211
-128
mindspore/_extends/parse/standard_method.py
mindspore/_extends/parse/standard_method.py
+2
-1
mindspore/ccsrc/frontend/operator/ops.h
mindspore/ccsrc/frontend/operator/ops.h
+0
-1
mindspore/ccsrc/frontend/operator/prim_arrays.cc
mindspore/ccsrc/frontend/operator/prim_arrays.cc
+0
-18
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+1
-0
mindspore/ccsrc/pipeline/jit/resource.cc
mindspore/ccsrc/pipeline/jit/resource.cc
+39
-22
mindspore/ccsrc/pipeline/jit/resource.h
mindspore/ccsrc/pipeline/jit/resource.h
+7
-3
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+32
-24
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+0
-6
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+18
-10
tests/ut/cpp/operator/ops_test.cc
tests/ut/cpp/operator/ops_test.cc
+0
-5
tests/ut/cpp/pipeline/resource_test.cc
tests/ut/cpp/pipeline/resource_test.cc
+15
-15
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
+0
-18
tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
.../ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
+96
-0
tests/ut/python/pipeline/parse/test_super.py
tests/ut/python/pipeline/parse/test_super.py
+1
-5
未找到文件。
mindspore/_extends/parse/standard_method.py
浏览文件 @
b075674c
...
...
@@ -28,7 +28,8 @@ from ...ops.composite.base import _append
__all__
=
[
'MultitypeFuncGraph'
,
'env_get'
,
'hyper_add'
,
'zeros_like'
,
'ones_like'
]
trans
=
P
.
Transpose
()
shape_
=
P
.
Shape
()
dtype_
=
P
.
DType
()
def
transpose
(
x
):
"""Implementation of `transpose`."""
...
...
mindspore/ccsrc/frontend/operator/ops.h
浏览文件 @
b075674c
...
...
@@ -93,7 +93,6 @@ inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("arra
inline
const
PrimitivePtr
kPrimBroadcastShape
=
std
::
make_shared
<
Primitive
>
(
"broadcast_shape"
);
inline
const
PrimitivePtr
kPrimArrayMap
=
std
::
make_shared
<
Primitive
>
(
"array_map"
);
inline
const
PrimitivePtr
kPrimArrayReduce
=
std
::
make_shared
<
Primitive
>
(
"array_reduce"
);
inline
const
PrimitivePtr
kPrimShape
=
std
::
make_shared
<
Primitive
>
(
"Shape"
);
inline
const
PrimitivePtr
kPrimCast
=
std
::
make_shared
<
Primitive
>
(
"Cast"
);
inline
const
PrimitivePtr
kPrimConcat
=
std
::
make_shared
<
Primitive
>
(
"Concat"
);
inline
const
PrimitivePtr
kPrimSqueeze
=
std
::
make_shared
<
Primitive
>
(
"Squeeze"
);
...
...
mindspore/ccsrc/frontend/operator/prim_arrays.cc
浏览文件 @
b075674c
...
...
@@ -15,7 +15,6 @@
*/
#include "pipeline/jit/static_analysis/prim.h"
#include "frontend/operator/ops.h"
#include "abstract/utils.h"
#include "frontend/operator/cc_implementations.h"
#include "abstract/param_validator.h"
...
...
@@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
return
std
::
make_shared
<
AbstractTuple
>
(
elems
);
}
AbstractBasePtr
InferImplShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tensor.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
1
);
AbstractTensorPtr
arg
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
MS_LOG
(
DEBUG
)
<<
"InferImplShape:"
<<
arg
->
ToString
();
AbstractBasePtrList
values
;
auto
shp
=
arg
->
shape
();
for
(
int
entry
:
shp
->
shape
())
{
auto
entry_v
=
MakeValue
(
entry
);
values
.
push_back
(
std
::
make_shared
<
AbstractScalar
>
(
entry_v
,
entry_v
->
type
()));
}
return
std
::
make_shared
<
AbstractTuple
>
(
values
);
}
AbstractBasePtr
InferImplTile
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: a tensor and a tuple.
...
...
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
b075674c
...
...
@@ -963,6 +963,7 @@ void ClearResAtexit() {
abstract
::
ClearPrimEvaluatorMap
();
compile
::
ClearConvertCache
();
pipeline
::
GetMethodMap
().
clear
();
pipeline
::
GetAttrMap
().
clear
();
pipeline
::
ExecutorPy
::
ClearRes
();
pipeline
::
ReclaimOptimizer
();
pynative
::
PynativeExecutor
::
GetInstance
()
->
ClearRes
();
...
...
mindspore/ccsrc/pipeline/jit/resource.cc
浏览文件 @
b075674c
...
...
@@ -17,23 +17,20 @@
*/
#include "pipeline/jit/resource.h"
#include "pipeline/jit/pipeline.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "debug/draw.h"
#include "debug/trace.h"
#include "ir/dtype.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"
#include "ir/graph_utils.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "vm/segment_runner.h"
namespace
mindspore
{
// namespace to support opmap definition
namespace
pipeline
{
Method
Map
&
GetMethodMap
()
{
static
Method
Map
method_map
=
{
BuiltInType
Map
&
GetMethodMap
()
{
static
BuiltInType
Map
method_map
=
{
{
kObjectTypeString
,
{
{
"__bool__"
,
std
::
string
(
"str_bool"
)}
// C.str_bool
...
...
@@ -191,6 +188,15 @@ MethodMap &GetMethodMap() {
return
method_map
;
}
BuiltInTypeMap
&
GetAttrMap
()
{
static
BuiltInTypeMap
attr_map
=
{{
kObjectTypeTensorType
,
{
{
"shape"
,
std
::
string
(
"shape_"
)},
// C.shape_
{
"dtype"
,
std
::
string
(
"dtype_"
)},
// C.dtype_
}}};
return
attr_map
;
}
Resource
::
Resource
(
const
py
::
object
&
obj
)
:
engine_
(
std
::
make_shared
<
abstract
::
AnalysisEngine
>
(
abstract
::
GetPrimEvaluatorConstructors
(),
manager_
)),
input_
(
obj
),
...
...
@@ -218,31 +224,42 @@ Resource::~Resource() {
}
}
bool
Resource
::
IsTypeInMethodMap
(
const
TypeId
&
type
)
{
TypeId
type_id
=
NormalizeTypeId
(
type
);
const
MethodMap
&
method_map
=
GetMethodMap
();
auto
iter
=
method_map
.
find
(
static_cast
<
int
>
(
type_id
));
if
(
iter
!=
method_map
.
end
())
{
return
true
;
Any
GetMethodOrAttr
(
const
string
&
name
,
const
TypeId
&
type_id
,
const
BuiltInTypeMap
&
method_map
)
{
auto
type_method_map
=
method_map
.
find
(
static_cast
<
int
>
(
type_id
));
if
(
type_method_map
==
method_map
.
end
())
{
return
Any
();
}
return
false
;
auto
method
=
type_method_map
->
second
.
find
(
name
);
if
(
method
==
type_method_map
->
second
.
end
())
{
return
Any
();
}
return
method
->
second
;
}
Any
Resource
::
GetMethodPtr
(
const
TypeId
&
type
,
const
std
::
string
&
nam
e
)
{
bool
Resource
::
IsTypeInBuiltInMap
(
const
TypeId
&
typ
e
)
{
TypeId
type_id
=
NormalizeTypeId
(
type
);
const
Method
Map
&
method_map
=
GetMethodMap
();
const
BuiltInType
Map
&
method_map
=
GetMethodMap
();
auto
iter
=
method_map
.
find
(
static_cast
<
int
>
(
type_id
));
if
(
iter
==
method_map
.
end
())
{
MS_LOG
(
WARNING
)
<<
"Object type: "
<<
type_id
<<
" not in the method_map"
;
return
Any
();
const
BuiltInTypeMap
&
attr_map
=
GetAttrMap
();
iter
=
attr_map
.
find
(
static_cast
<
int
>
(
type_id
));
if
(
iter
==
attr_map
.
end
())
{
return
false
;
}
}
return
true
;
}
auto
iter_map
=
iter
->
second
.
find
(
name
);
if
(
iter_map
==
iter
->
second
.
end
())
{
MS_LOG
(
WARNING
)
<<
"Object type: "
<<
type_id
<<
" have no method: "
<<
name
;
return
Any
();
}
return
iter_map
->
second
;
Any
Resource
::
GetMethodPtr
(
const
TypeId
&
type
,
const
std
::
string
&
name
)
{
TypeId
type_id
=
NormalizeTypeId
(
type
);
const
BuiltInTypeMap
&
method_map
=
GetMethodMap
();
return
GetMethodOrAttr
(
name
,
type_id
,
method_map
);
}
Any
Resource
::
GetAttrPtr
(
const
TypeId
&
type
,
const
std
::
string
&
name
)
{
TypeId
type_id
=
NormalizeTypeId
(
type
);
const
BuiltInTypeMap
&
attr_map
=
GetAttrMap
();
return
GetMethodOrAttr
(
name
,
type_id
,
attr_map
);
}
void
Resource
::
Clean
()
{
...
...
mindspore/ccsrc/pipeline/jit/resource.h
浏览文件 @
b075674c
...
...
@@ -44,9 +44,11 @@ const char kOutput[] = "output";
class
InferenceResource
;
using
Method
Map
=
std
::
unordered_map
<
int
,
std
::
unordered_map
<
std
::
string
,
Any
>>
;
using
BuiltInType
Map
=
std
::
unordered_map
<
int
,
std
::
unordered_map
<
std
::
string
,
Any
>>
;
MethodMap
&
GetMethodMap
();
BuiltInTypeMap
&
GetMethodMap
();
BuiltInTypeMap
&
GetAttrMap
();
class
ResourceBase
{
public:
...
...
@@ -87,10 +89,12 @@ class Resource : public ResourceBase {
abstract
::
AnalysisEnginePtr
engine
()
{
return
engine_
;
}
static
bool
IsTypeIn
Method
Map
(
const
TypeId
&
type
);
static
bool
IsTypeIn
BuiltIn
Map
(
const
TypeId
&
type
);
static
Any
GetMethodPtr
(
const
TypeId
&
type
,
const
std
::
string
&
name
);
static
Any
GetAttrPtr
(
const
TypeId
&
type
,
const
std
::
string
&
name
);
const
py
::
object
&
input
()
const
{
return
input_
;
}
FuncGraphPtr
func_graph
()
const
{
return
func_graph_
;
}
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
b075674c
...
...
@@ -21,7 +21,6 @@
#include <algorithm>
#include <limits>
#include <mutex>
#include <set>
#include <string>
#include <utility>
...
...
@@ -31,10 +30,8 @@
#include "frontend/operator/prim_to_function.h"
#include "abstract/utils.h"
#include "utils/symbolic.h"
#include "./common.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/resolve.h"
#include "ir/tensor.h"
#include "utils/convert_utils.h"
#include "utils/context/ms_context.h"
#include "pipeline/jit/parse/data_converter.h"
...
...
@@ -64,7 +61,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimScalarToArray
,
{
InferImplScalarToArray
,
true
}},
{
prim
::
kPrimArrayToScalar
,
{
InferImplArrayToScalar
,
true
}},
{
prim
::
kPrimBroadcastShape
,
{
InferImplBroadCastShape
,
true
}},
{
prim
::
kPrimShape
,
{
InferImplShape
,
true
}},
{
prim
::
kPrimPack
,
{
InferImplPack
,
true
}},
// Structure
{
prim
::
kPrimMakeTuple
,
{
InferImplMakeTuple
,
true
}},
...
...
@@ -634,7 +630,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm
}
const
int
kResolveCaseUserDefineClass
=
1
;
const
int
kResolveCaseBuil
dinTypeMethod
=
2
;
const
int
kResolveCaseBuil
tInType
=
2
;
const
int
kResolveCaseFunction
=
3
;
int
GetResolveCase
(
const
TypePtr
&
data_type
)
{
MS_EXCEPTION_IF_NULL
(
data_type
);
...
...
@@ -643,8 +639,8 @@ int GetResolveCase(const TypePtr &data_type) {
}
// try method map, if not in method map, the data_type should be External type.
if
(
pipeline
::
Resource
::
IsTypeIn
Method
Map
(
data_type
->
type_id
()))
{
return
kResolveCaseBuil
dinTypeMethod
;
if
(
pipeline
::
Resource
::
IsTypeIn
BuiltIn
Map
(
data_type
->
type_id
()))
{
return
kResolveCaseBuil
tInType
;
}
return
kResolveCaseFunction
;
...
...
@@ -674,8 +670,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
manager
->
AddFuncGraph
(
func_graph
);
}
EvalResultPtr
StaticGetterInferred
(
const
ValuePtr
&
value
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
old_conf
)
{
enum
REQUIRE_TYPE
{
ATTR
,
METHOD
};
EvalResultPtr
StaticGetterInferred
(
const
ValuePtr
&
value
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
old_conf
,
REQUIRE_TYPE
require_type
=
REQUIRE_TYPE
::
METHOD
)
{
MS_EXCEPTION_IF_NULL
(
old_conf
);
AbstractBasePtr
abs_ptr
=
ToAbstract
(
value
,
AnalysisContext
::
DummyContext
(),
old_conf
);
...
...
@@ -701,6 +699,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
MS_EXCEPTION_IF_NULL
(
old_conf
);
FuncGraphPtr
func_graph
=
old_conf
->
node
()
->
func_graph
();
CNodePtr
new_cnode
=
func_graph
->
NewCNode
(
input
);
if
(
require_type
==
REQUIRE_TYPE
::
ATTR
)
{
new_cnode
=
func_graph
->
NewCNode
({
new_cnode
});
}
AnalysisEnginePtr
eng
=
old_conf
->
engine
();
AnfNodeConfigPtr
fn_conf
=
eng
->
MakeConfig
(
new_cnode
,
old_conf
->
context
());
return
eng
->
ForwardConfig
(
old_conf
,
fn_conf
);
...
...
@@ -781,9 +782,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
return
StaticGetterInferred
(
converted_v
,
data_conf
,
out_conf
);
}
EvalResultPtr
GetEvaluatedValueForBuiltinTypeMethod
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
item_v
,
const
TypePtr
&
data_type
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
EvalResultPtr
GetEvaluatedValueForBuiltinType
AttrOr
Method
(
const
AnalysisEnginePtr
&
engine
,
const
ValuePtr
&
item_v
,
const
TypePtr
&
data_type
,
const
ConfigPtr
&
data_conf
,
const
AnfNodeConfigPtr
&
out_conf
)
{
MS_EXCEPTION_IF_NULL
(
item_v
);
MS_EXCEPTION_IF_NULL
(
data_type
);
// The method maybe a Primitive or Composite
...
...
@@ -792,22 +793,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng
}
std
::
string
item_name
=
item_v
->
cast
<
StringImmPtr
>
()
->
value
();
Any
method
=
pipeline
::
Resource
::
GetMethodPtr
(
data_type
->
type_id
(),
item_name
);
if
(
method
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"Object type: "
<<
data_type
->
ToString
()
<<
" has no method: "
<<
item_name
;
REQUIRE_TYPE
require_type
=
REQUIRE_TYPE
::
METHOD
;
Any
require
=
pipeline
::
Resource
::
GetMethodPtr
(
data_type
->
type_id
(),
item_name
);
if
(
require
.
empty
())
{
require
=
pipeline
::
Resource
::
GetAttrPtr
(
data_type
->
type_id
(),
item_name
);
if
(
require
.
empty
())
{
MS_LOG
(
EXCEPTION
)
<<
"The object of type: "
<<
data_type
->
ToString
()
<<
" has no method or attr: "
<<
item_name
;
}
require_type
=
REQUIRE_TYPE
::
ATTR
;
}
ValuePtr
converted_v
=
nullptr
;
if
(
method
.
is
<
std
::
string
>
())
{
if
(
require
.
is
<
std
::
string
>
())
{
// composite registered in standard_method_map go to this branch
converted_v
=
prim
::
GetPythonOps
(
method
.
cast
<
std
::
string
>
());
AddToManager
(
engine
,
converted_v
->
cast
<
FuncGraphPtr
>
());
}
else
if
(
method
.
is
<
PrimitivePtr
>
())
{
converted_v
=
method
.
cast
<
PrimitivePtr
>
();
converted_v
=
prim
::
GetPythonOps
(
require
.
cast
<
std
::
string
>
());
if
(
!
converted_v
->
isa
<
Primitive
>
())
{
AddToManager
(
engine
,
converted_v
->
cast
<
FuncGraphPtr
>
());
}
}
else
if
(
require
.
is
<
PrimitivePtr
>
())
{
converted_v
=
require
.
cast
<
PrimitivePtr
>
();
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Expect to get string or PrimitivePtr from
method map, but got "
<<
method
.
ToString
();
MS_LOG
(
EXCEPTION
)
<<
"Expect to get string or PrimitivePtr from
attr or method map, but got "
<<
require
.
ToString
();
}
return
StaticGetterInferred
(
converted_v
,
data_conf
,
out_conf
);
return
StaticGetterInferred
(
converted_v
,
data_conf
,
out_conf
,
require_type
);
}
EvalResultPtr
StaticGetter
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
...
...
@@ -831,8 +839,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
int
case_v
=
GetResolveCase
(
data_type
);
if
(
case_v
==
kResolveCaseUserDefineClass
)
{
return
GetEvaluatedValueForClassAttrOrMethod
(
engine
,
args_spec_list
,
item_value
,
data_conf
,
out_conf
);
}
else
if
(
case_v
==
kResolveCaseBuil
dinTypeMethod
)
{
return
GetEvaluatedValueForBuiltinTypeMethod
(
engine
,
item_value
,
data_type
,
data_conf
,
out_conf
);
}
else
if
(
case_v
==
kResolveCaseBuil
tInType
)
{
return
GetEvaluatedValueForBuiltinType
AttrOr
Method
(
engine
,
item_value
,
data_type
,
data_conf
,
out_conf
);
}
else
{
return
GetEvaluatedValueForNameSpaceString
(
engine
,
args_spec_list
,
out_conf
);
}
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
浏览文件 @
b075674c
...
...
@@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBiasAddGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGelu
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplGeluGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplRelu
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplFakeBprop
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
...
...
@@ -246,8 +242,6 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplBroadCastShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplShape
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPack
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
b075674c
...
...
@@ -22,20 +22,21 @@ import copy
import
functools
import
itertools
import
numbers
import
numpy
as
np
from
..._checkparam
import
Validator
as
validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
...common.parameter
import
Parameter
from
..operations.math_ops
import
_infer_shape_reduce
from
.._utils
import
get_concat_offset
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..operations.math_ops
import
_infer_shape_reduce
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
from
..._c_expression
import
signature_dtype
as
sig_dtype
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
typing
from
..._checkparam
import
Rel
from
..._checkparam
import
Validator
as
validator
from
...common
import
dtype
as
mstype
from
...common.parameter
import
Parameter
from
...common.tensor
import
Tensor
class
_ScatterOp
(
PrimitiveWithInfer
):
...
...
@@ -415,7 +416,7 @@ class Reshape(PrimitiveWithInfer):
return
out
class
Shape
(
Primitive
):
class
Shape
(
Primitive
WithInfer
):
"""
Returns the shape of input tensor.
...
...
@@ -436,6 +437,13 @@ class Shape(Primitive):
def
__init__
(
self
):
"""init Shape"""
def
__infer__
(
self
,
x
):
validator
.
check_subclass
(
"input_x"
,
x
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
out
=
{
'shape'
:
(),
'dtype'
:
mstype
.
tuple_
,
'value'
:
tuple
(
x
[
'shape'
])}
return
out
class
Squeeze
(
PrimitiveWithInfer
):
"""
...
...
tests/ut/cpp/operator/ops_test.cc
浏览文件 @
b075674c
...
...
@@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) {
ASSERT_EQ
(
prim
->
name
(),
kPrimBroadcastShape
->
name
());
}
TEST_F
(
TestOps
,
ShapeTest
)
{
auto
prim
=
std
::
make_shared
<
Primitive
>
(
"Shape"
);
ASSERT_EQ
(
prim
->
name
(),
kPrimShape
->
name
());
}
TEST_F
(
TestOps
,
ArrayMapTest
)
{
auto
prim
=
std
::
make_shared
<
Primitive
>
(
"array_map"
);
ASSERT_EQ
(
prim
->
name
(),
kPrimArrayMap
->
name
());
...
...
tests/ut/cpp/pipeline/resource_test.cc
浏览文件 @
b075674c
...
...
@@ -36,23 +36,23 @@ class TestResource : public UT::Common {
void
TearDown
()
{}
};
TEST_F
(
TestResource
,
test_
standard_method
_map
)
{
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeInt
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeInt8
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeInt16
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeInt32
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeInt64
));
TEST_F
(
TestResource
,
test_
built_in_type
_map
)
{
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeInt
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeInt8
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeInt16
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeInt32
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeInt64
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeFloat
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeFloat16
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeFloat32
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeFloat64
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeFloat
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeFloat16
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeFloat32
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeFloat64
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeBool
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kNumberTypeUInt
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kObjectTypeTuple
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kObjectTypeList
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
Method
Map
(
kObjectTypeTensorType
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeBool
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kNumberTypeUInt
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kObjectTypeTuple
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kObjectTypeList
));
ASSERT_TRUE
(
true
==
Resource
::
IsTypeIn
BuiltIn
Map
(
kObjectTypeTensorType
));
MethodMap
&
map
=
GetMethodMap
();
for
(
auto
&
iter
:
map
)
{
...
...
tests/ut/cpp/pipeline/static_analysis/prim_test.cc
浏览文件 @
b075674c
...
...
@@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) {
ASSERT_TRUE
(
*
res
==
*
exp
);
}
TEST_F
(
TestPrim
,
test_shape
)
{
PrimitivePtr
shap
=
std
::
make_shared
<
Primitive
>
(
"Shape"
);
FuncGraphPtr
func_graph
=
MakeFuncGraph
(
shap
,
1
);
auto
a
=
UTPrimUtils
::
ArrayFloat64Of
({
2
,
3
});
AbstractBasePtrList
args_spec_list
=
{
a
};
AbstractTuplePtr
res
=
dyn_cast
<
AbstractTuple
>
(
engine_
->
Run
(
func_graph
,
args_spec_list
).
inferred
->
abstract
());
auto
ret
=
res
->
BuildValue
()
->
cast
<
ValueTuplePtr
>
()
->
value
();
std
::
vector
<
ValuePtr
>
element_list
=
{
MakeValue
(
2
),
MakeValue
(
3
)};
ASSERT_TRUE
(
ret
.
size
()
==
element_list
.
size
());
for
(
int
i
=
0
;
i
<
element_list
.
size
();
i
++
)
{
ASSERT_TRUE
(
*
ret
[
i
]
==
*
element_list
[
i
]);
}
}
TEST_F
(
TestPrim
,
test_relu
)
{
PrimitivePtr
relu
=
prim
::
kPrimRelu
;
relu
->
AddAttr
(
"T"
,
MakeValue
(
static_cast
<
int
>
(
kNumberTypeFloat64
)));
...
...
tests/ut/python/pipeline/parse/test_dtype_and_shape_as_attr.py
0 → 100644
浏览文件 @
b075674c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test dtype and shape as attr"""
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
dtype
as
mstype
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_dtype_and_shape_as_attr
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
):
shape
=
x
.
shape
dtype
=
x
.
dtype
return
shape
,
dtype
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
ret
=
net
(
x
)
assert
ret
==
((
1
,
2
,
3
),
mstype
.
int32
)
def
test_dtype_and_shape_as_attr_to_new_tensor
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
value
):
super
(
Net
,
self
).
__init__
()
self
.
fill
=
P
.
Fill
()
self
.
value
=
value
def
construct
(
self
,
x
):
dtype
=
x
.
dtype
shape
=
x
.
shape
y
=
self
.
fill
(
dtype
,
shape
,
self
.
value
)
return
y
net
=
Net
(
2.2
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
float32
))
ret
=
net
(
x
)
assert
(
ret
.
asnumpy
()
==
(
np
.
zeros
([
1
,
2
,
3
],
np
.
float32
)
+
2.2
)).
all
()
def
test_type_not_have_the_attr
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
):
shape
=
x
.
shapes
return
shape
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
ex
:
net
(
x
)
assert
"The object of type: Tensor[Int32] has no method or attr: shapes"
in
str
(
ex
.
value
)
def
test_type_not_have_the_method
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
):
shape
=
x
.
dtypes
()
return
shape
net
=
Net
()
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
ex
:
net
(
x
)
assert
"The object of type: Tensor[Int32] has no method or attr: dtypes"
in
str
(
ex
.
value
)
tests/ut/python/pipeline/parse/test_super.py
浏览文件 @
b075674c
...
...
@@ -20,7 +20,7 @@ import mindspore.nn as nn
from
mindspore
import
Tensor
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
class
FatherNet
(
nn
.
Cell
):
...
...
@@ -92,7 +92,6 @@ class Net(nn.Cell):
def
test_single_super
():
single_net
=
SingleSubNet
(
2
,
3
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
single_net
(
x
,
y
)
...
...
@@ -100,7 +99,6 @@ def test_single_super():
def
test_mul_super
():
mul_net
=
MulSubNet
(
2
,
3
,
4
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
mul_net
(
x
,
y
)
...
...
@@ -108,7 +106,6 @@ def test_mul_super():
def
test_super_cell
():
net
=
Net
(
2
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
with
pytest
.
raises
(
RuntimeError
)
as
er
:
...
...
@@ -142,7 +139,6 @@ def test_single_super_in():
return
ret_father_construct
,
ret_father_test
,
ret_father_x
,
ret_sub_z
single_net_in
=
SingleSubNetIN
(
2
,
3
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
x
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
y
=
Tensor
(
np
.
ones
([
1
,
2
,
3
],
np
.
int32
))
single_net_in
(
x
,
y
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录