Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8ac5672a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8ac5672a
编写于
7月 27, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add support for dynamic shape
上级
779c668a
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
258 addition
and
12 deletion
+258
-12
mindspore/ccsrc/frontend/operator/ops.h
mindspore/ccsrc/frontend/operator/ops.h
+2
-0
mindspore/ccsrc/frontend/operator/prim_arrays.cc
mindspore/ccsrc/frontend/operator/prim_arrays.cc
+42
-0
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+15
-1
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+4
-0
mindspore/ccsrc/utils/convert_utils.cc
mindspore/ccsrc/utils/convert_utils.cc
+15
-3
mindspore/ccsrc/utils/convert_utils.h
mindspore/ccsrc/utils/convert_utils.h
+3
-1
mindspore/core/abstract/dshape.cc
mindspore/core/abstract/dshape.cc
+3
-0
mindspore/core/abstract/dshape.h
mindspore/core/abstract/dshape.h
+8
-2
mindspore/core/abstract/utils.cc
mindspore/core/abstract/utils.cc
+52
-1
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+20
-0
mindspore/ops/_utils/utils.py
mindspore/ops/_utils/utils.py
+4
-1
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+2
-1
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+25
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+23
-1
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+40
-1
未找到文件。
mindspore/ccsrc/frontend/operator/ops.h
浏览文件 @
8ac5672a
...
...
@@ -113,6 +113,8 @@ inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransDat
inline
const
PrimitivePtr
kPrimNMSWithMask
=
std
::
make_shared
<
Primitive
>
(
"NMSWithMask"
);
inline
const
PrimitivePtr
kPrimPad
=
std
::
make_shared
<
Primitive
>
(
"Pad"
);
inline
const
PrimitivePtr
kPrimArgMaxWithValue
=
std
::
make_shared
<
Primitive
>
(
"ArgMaxWithValue"
);
inline
const
PrimitivePtr
kPrimUnique
=
std
::
make_shared
<
Primitive
>
(
"Unique"
);
inline
const
PrimitivePtr
kPrimUniqueGrad
=
std
::
make_shared
<
Primitive
>
(
"UniqueGrad"
);
// NN
inline
const
PrimitivePtr
kPrimFlatten
=
std
::
make_shared
<
Primitive
>
(
"Flatten"
);
...
...
mindspore/ccsrc/frontend/operator/prim_arrays.cc
浏览文件 @
8ac5672a
...
...
@@ -148,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri
ret
->
set_shape
(
std
::
make_shared
<
Shape
>
(
shape
));
return
ret
;
}
AbstractBasePtr
InferImplUnique
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// inputs: a 1-d Tensor
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
1
);
AbstractTensorPtr
input
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
0
);
auto
shape
=
input
->
shape
();
if
(
shape
->
shape
().
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Rank of "
<<
op_name
<<
"'s input must be 1."
;
}
std
::
vector
<
int
>
ids_shape
=
{
Shape
::
SHP_ANY
};
std
::
vector
<
int
>
min_shape
=
{
1
};
std
::
vector
<
int
>
max_shape
=
shape
->
shape
();
auto
ids
=
std
::
make_shared
<
AbstractTensor
>
(
input
->
element
(),
std
::
make_shared
<
Shape
>
(
ids_shape
,
min_shape
,
max_shape
));
auto
ids_idx
=
std
::
make_shared
<
AbstractTensor
>
(
std
::
make_shared
<
Int
>
(
32
),
shape
->
shape
());
// outputs: ids, ids_idx
AbstractBasePtrList
elements
=
{
ids
,
ids_idx
};
return
std
::
make_shared
<
AbstractTuple
>
(
elements
);
}
AbstractBasePtr
InferImplUniqueGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// inputs: a 1-d Tensor
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractTuplePtr
dout
=
CheckArg
<
AbstractTuple
>
(
op_name
,
args_spec_list
,
0
);
CheckArgsSize
(
op_name
+
" dout"
,
dout
->
elements
(),
2
);
auto
ids
=
CheckArg
<
AbstractTensor
>
(
op_name
,
dout
->
elements
(),
0
);
auto
ids_idx
=
CheckArg
<
AbstractTensor
>
(
op_name
,
dout
->
elements
(),
1
);
if
(
ids
->
shape
()
->
shape
().
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Dims of dout[0] of "
<<
op_name
<<
"' input must be 1."
;
}
if
(
ids_idx
->
shape
()
->
shape
().
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"Dims of dout[1] of "
<<
op_name
<<
"' input must be 1."
;
}
// outputs: dx
return
std
::
make_shared
<
AbstractTensor
>
(
ids
->
element
(),
ids_idx
->
shape
());
}
}
// namespace abstract
}
// namespace mindspore
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
8ac5672a
...
...
@@ -23,6 +23,7 @@
#include <mutex>
#include <string>
#include <utility>
#include <unordered_set>
#include "frontend/operator/cc_implementations.h"
#include "frontend/operator/ops.h"
...
...
@@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimArrayToScalar
,
{
InferImplArrayToScalar
,
true
}},
{
prim
::
kPrimBroadcastShape
,
{
InferImplBroadCastShape
,
true
}},
{
prim
::
kPrimPack
,
{
InferImplPack
,
true
}},
{
prim
::
kPrimUnique
,
{
InferImplUnique
,
true
}},
{
prim
::
kPrimUniqueGrad
,
{
InferImplUniqueGrad
,
true
}},
// Structure
{
prim
::
kPrimMakeTuple
,
{
InferImplMakeTuple
,
true
}},
{
prim
::
kPrimMakeList
,
{
InferImplMakeList
,
true
}},
...
...
@@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if
(
abs_base
->
isa
<
AbstractTensor
>
())
{
auto
arg_tensor
=
dyn_cast
<
AbstractTensor
>
(
abs_base
);
dic
[
"shape"
]
=
arg_tensor
->
shape
()
->
shape
();
if
(
MsContext
::
GetInstance
()
->
execution_mode
()
==
kGraphMode
)
{
const
auto
&
min_shape
=
arg_tensor
->
shape
()
->
min_shape
();
const
auto
&
max_shape
=
arg_tensor
->
shape
()
->
max_shape
();
if
(
!
min_shape
.
empty
()
&&
!
max_shape
.
empty
())
{
dic
[
"min_shape"
]
=
min_shape
;
dic
[
"max_shape"
]
=
max_shape
;
}
}
dic
[
"dtype"
]
=
arg_tensor
->
BuildType
();
dic
[
"value"
]
=
BuildValue
(
arg_tensor
->
BuildValue
());
}
else
if
(
abs_base
->
isa
<
AbstractIndexedSlices
>
())
{
...
...
@@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
if
(
output
[
"value"
].
is_none
())
{
auto
out_shape
=
output
[
"shape"
];
auto
out_dtype
=
output
[
"dtype"
];
return
PyListDtype2AbstractTensor
(
out_shape
,
out_dtype
);
py
::
object
min_shape
=
output
.
contains
(
"min_shape"
)
?
(
py
::
object
)
output
[
"min_shape"
]
:
(
py
::
object
)
py
::
none
();
py
::
object
max_shape
=
output
.
contains
(
"max_shape"
)
?
(
py
::
object
)
output
[
"max_shape"
]
:
(
py
::
object
)
py
::
none
();
return
PyListDtype2AbstractTensor
(
out_shape
,
out_dtype
,
min_shape
,
max_shape
);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr
converted_ret
=
nullptr
;
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
浏览文件 @
8ac5672a
...
...
@@ -244,6 +244,10 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPack
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUnique
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplUniqueGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
...
...
mindspore/ccsrc/utils/convert_utils.cc
浏览文件 @
8ac5672a
...
...
@@ -371,7 +371,8 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
return
ret
;
}
AbstractBasePtr
PyListDtype2AbstractTensor
(
const
py
::
object
&
shape_obj
,
const
py
::
object
&
type_obj
)
{
AbstractBasePtr
PyListDtype2AbstractTensor
(
const
py
::
object
&
shape_obj
,
const
py
::
object
&
type_obj
,
const
py
::
object
&
min_shape
,
const
py
::
object
&
max_shape
)
{
if
((
py
::
isinstance
<
py
::
list
>
(
shape_obj
)
||
py
::
isinstance
<
py
::
tuple
>
(
shape_obj
))
&&
py
::
isinstance
<
Type
>
(
type_obj
))
{
auto
ret_vec
=
shape_obj
.
cast
<
std
::
vector
<
int
>>
();
auto
ret_dtype
=
type_obj
.
cast
<
TypePtr
>
();
...
...
@@ -382,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
return
abs_scalar
;
}
AbstractBasePtr
tensor
=
nullptr
;
std
::
vector
<
int
>
min_shape_vec
;
std
::
vector
<
int
>
max_shape_vec
;
if
(
!
min_shape
.
is_none
())
{
min_shape_vec
=
min_shape
.
cast
<
std
::
vector
<
int
>>
();
}
if
(
!
max_shape
.
is_none
())
{
max_shape_vec
=
max_shape
.
cast
<
std
::
vector
<
int
>>
();
}
auto
ret_shape
=
std
::
make_shared
<
abstract
::
Shape
>
(
ret_vec
,
min_shape_vec
,
max_shape_vec
);
if
(
ret_dtype
->
isa
<
TensorType
>
())
{
auto
tensor_type
=
type_obj
.
cast
<
TensorTypePtr
>
();
MS_EXCEPTION_IF_NULL
(
tensor_type
);
tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
tensor_type
->
element
(),
ret_vec
);
auto
element
=
std
::
make_shared
<
abstract
::
AbstractScalar
>
(
kAnyValue
,
tensor_type
->
element
());
tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
element
,
ret_shape
);
}
else
{
tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
ret_dtype
,
ret_vec
);
auto
element
=
std
::
make_shared
<
abstract
::
AbstractScalar
>
(
kAnyValue
,
ret_dtype
);
tensor
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
element
,
ret_shape
);
}
return
tensor
;
}
else
if
(
py
::
isinstance
<
py
::
tuple
>
(
shape_obj
)
&&
py
::
isinstance
<
py
::
tuple
>
(
type_obj
))
{
...
...
mindspore/ccsrc/utils/convert_utils.h
浏览文件 @
8ac5672a
...
...
@@ -47,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value);
bool
ValueToBool
(
const
ValuePtr
&
in
,
bool
*
out
);
py
::
object
ValuePtrToPyData
(
const
ValuePtr
&
value
);
AbstractBasePtr
PyListDtype2AbstractTensor
(
const
py
::
object
&
shape_obj
,
const
py
::
object
&
type_obj
);
AbstractBasePtr
PyListDtype2AbstractTensor
(
const
py
::
object
&
shape_obj
,
const
py
::
object
&
type_obj
,
const
py
::
object
&
min_shape
=
py
::
none
(),
const
py
::
object
&
max_shape
=
py
::
none
());
bool
IsGraphOutputValueNodeOrParameter
(
const
AnfNodePtr
&
output
,
const
py
::
tuple
&
args
,
const
std
::
shared_ptr
<
py
::
object
>
&
ret_val
);
...
...
mindspore/core/abstract/dshape.cc
浏览文件 @
8ac5672a
...
...
@@ -67,6 +67,9 @@ std::string Shape::DumpText() const {
buffer
<<
"["
;
for
(
size_t
i
=
0
;
i
<
shape_
.
size
();
i
++
)
{
buffer
<<
(
i
>
0
?
", "
:
""
)
<<
shape_
[
i
];
if
(
shape_
[
i
]
==
SHP_ANY
&&
min_shape_
.
size
()
==
shape_
.
size
()
&&
max_shape_
.
size
()
==
shape_
.
size
())
{
buffer
<<
"_"
<<
min_shape_
[
i
]
<<
"^"
<<
max_shape_
[
i
];
}
}
buffer
<<
"]"
;
return
buffer
.
str
();
...
...
mindspore/core/abstract/dshape.h
浏览文件 @
8ac5672a
...
...
@@ -74,16 +74,22 @@ class Shape : public BaseShape {
(
void
)
std
::
transform
(
list
.
begin
(),
list
.
end
(),
std
::
back_inserter
(
shape_
),
[](
const
int64_t
&
value
)
{
return
static_cast
<
int
>
(
value
);
});
}
Shape
(
const
std
::
vector
<
int
>
&
list
,
const
std
::
vector
<
int
>
&
min_shape
,
const
std
::
vector
<
int
>
&
max_shape
)
:
shape_
(
list
),
min_shape_
(
min_shape
),
max_shape_
(
max_shape
)
{}
~
Shape
()
override
=
default
;
MS_DECLARE_PARENT
(
Shape
,
BaseShape
)
std
::
string
ToString
()
const
override
;
std
::
string
DumpText
()
const
override
;
bool
operator
==
(
const
BaseShape
&
other
)
const
override
;
BaseShapePtr
Clone
()
const
override
{
return
std
::
make_shared
<
Shape
>
(
shape_
);
}
BaseShapePtr
Clone
()
const
override
{
return
std
::
make_shared
<
Shape
>
(
shape_
,
min_shape_
,
max_shape_
);
}
void
Broaden
()
override
;
std
::
vector
<
int
>
&
shape
()
{
return
shape_
;
}
std
::
vector
<
int
>
&
min_shape
()
{
return
min_shape_
;
}
std
::
vector
<
int
>
&
max_shape
()
{
return
max_shape_
;
}
std
::
vector
<
int
>
shape_
;
// use SHP_ANY to implement the any shape in python
std
::
vector
<
int
>
shape_
;
// use SHP_ANY to implement the any shape in python
std
::
vector
<
int
>
min_shape_
;
// record mininum length for each dynamic dimention
std
::
vector
<
int
>
max_shape_
;
// record maximum length for each dynamic dimention
};
using
ShapePtr
=
std
::
shared_ptr
<
Shape
>
;
using
ShapePtrList
=
std
::
vector
<
ShapePtr
>
;
...
...
mindspore/core/abstract/utils.cc
浏览文件 @
8ac5672a
...
...
@@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
return
shape1
;
}
std
::
vector
<
int
>
dims
;
bool
has_dynamic_shape
=
false
;
dims
.
resize
(
shape1
->
shape
().
size
());
for
(
std
::
size_t
i
=
0
;
i
<
shape1
->
shape
().
size
();
i
++
)
{
if
(
shape1
->
shape
()[
i
]
==
shape2
->
shape
()[
i
])
{
dims
[
i
]
=
shape1
->
shape
()[
i
];
if
(
shape1
->
shape
()[
i
]
==
Shape
::
SHP_ANY
)
{
has_dynamic_shape
=
true
;
}
}
else
{
dims
[
i
]
=
Shape
::
SHP_ANY
;
has_dynamic_shape
=
true
;
}
}
return
std
::
make_shared
<
Shape
>
(
dims
);
if
(
!
has_dynamic_shape
)
{
return
std
::
make_shared
<
Shape
>
(
dims
);
}
// calculate dynamic shape
std
::
vector
<
int
>
min_dims
(
dims
.
size
());
std
::
vector
<
int
>
max_dims
(
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
if
(
dims
[
i
]
!=
Shape
::
SHP_ANY
)
{
min_dims
[
i
]
=
max_dims
[
i
]
=
dims
[
i
];
continue
;
}
if
(
shape1
->
shape
()[
i
]
!=
Shape
::
SHP_ANY
&&
shape2
->
shape
()[
i
]
!=
Shape
::
SHP_ANY
)
{
min_dims
[
i
]
=
std
::
min
(
shape1
->
shape
()[
i
],
shape2
->
shape
()[
i
]);
max_dims
[
i
]
=
std
::
max
(
shape1
->
shape
()[
i
],
shape2
->
shape
()[
i
]);
continue
;
}
if
(
shape1
->
shape
()[
i
]
==
Shape
::
SHP_ANY
&&
shape2
->
shape
()[
i
]
!=
Shape
::
SHP_ANY
)
{
if
(
shape1
->
min_shape
().
empty
()
||
shape1
->
max_shape
().
empty
())
{
MS_EXCEPTION
(
ValueError
)
<<
"Shape "
<<
shape1
->
ToString
()
<<
" has dynamic shape, but does not have min/max shape info."
;
}
min_dims
[
i
]
=
std
::
min
(
shape1
->
min_shape
()[
i
],
shape2
->
shape
()[
i
]);
max_dims
[
i
]
=
std
::
max
(
shape1
->
max_shape
()[
i
],
shape2
->
shape
()[
i
]);
continue
;
}
if
(
shape1
->
shape
()[
i
]
!=
Shape
::
SHP_ANY
&&
shape2
->
shape
()[
i
]
==
Shape
::
SHP_ANY
)
{
if
(
shape2
->
min_shape
().
empty
()
||
shape2
->
max_shape
().
empty
())
{
MS_EXCEPTION
(
ValueError
)
<<
"Shape "
<<
shape1
->
ToString
()
<<
" has dynamic shape, but does not have min/max shape info."
;
}
min_dims
[
i
]
=
std
::
min
(
shape1
->
shape
()[
i
],
shape2
->
min_shape
()[
i
]);
max_dims
[
i
]
=
std
::
max
(
shape1
->
shape
()[
i
],
shape2
->
max_shape
()[
i
]);
continue
;
}
// both shapes contains dynamic shape
if
(
shape1
->
min_shape
().
empty
()
||
shape1
->
max_shape
().
empty
())
{
MS_EXCEPTION
(
ValueError
)
<<
"Shape "
<<
shape1
->
ToString
()
<<
" has dynamic shape, but does not have min/max shape info."
;
}
if
(
shape2
->
min_shape
().
empty
()
||
shape2
->
max_shape
().
empty
())
{
MS_EXCEPTION
(
ValueError
)
<<
"Shape "
<<
shape2
->
ToString
()
<<
" has dynamic shape, but does not have min/max shape info."
;
}
min_dims
[
i
]
=
std
::
min
(
shape1
->
min_shape
()[
i
],
shape2
->
min_shape
()[
i
]);
max_dims
[
i
]
=
std
::
max
(
shape1
->
max_shape
()[
i
],
shape2
->
max_shape
()[
i
]);
}
return
std
::
make_shared
<
Shape
>
(
dims
,
min_dims
,
max_dims
);
}
AbstractBasePtr
AbstractJoin
(
const
AbstractBasePtrList
&
args_spec_list
)
{
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
8ac5672a
...
...
@@ -807,3 +807,23 @@ def get_bprop_trans_shape(self):
dx
=
op
(
dout
,
shape_op
(
x
))
return
(
dx
,
zeros_like
(
shape
))
return
bprop
@
bprop_getters
.
register
(
P
.
Unique
)
def
get_bprop_unique
(
self
):
"""Generate bprop for Unique"""
op
=
G
.
UniqueGrad
()
def
bprop
(
x
,
out
,
dout
):
dx
=
op
(
dout
,
out
)
return
(
dx
,)
return
bprop
@
bprop_getters
.
register
(
P
.
UnsortedSegmentSum
)
def
get_bprop_unsorted_segment_sum
(
self
):
"""Generate bprop for UnsortedSegmentSum"""
op
=
G
.
UnsortedSegmentSumGrad
()
def
bprop
(
x
,
segment_ids
,
num_segments
,
out
,
dout
):
dx
=
op
(
dout
,
segment_ids
)
return
(
dx
,
zeros_like
(
segment_ids
),
zeros_like
(
num_segments
))
return
bprop
mindspore/ops/_utils/utils.py
浏览文件 @
8ac5672a
...
...
@@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
if
j
!=
axis
and
v
[
j
]
!=
x_shp
[
0
][
j
]:
raise
ValueError
(
f
"For
\'
{
prim_name
}
\'
element
{
i
}
shape in input can not concat with first element"
)
offset
.
append
(
all_shp
)
all_shp
+=
v
[
axis
]
if
all_shp
==
-
1
or
v
[
axis
]
==
-
1
:
all_shp
=
-
1
else
:
all_shp
+=
v
[
axis
]
return
offset
,
all_shp
,
axis
mindspore/ops/operations/__init__.py
浏览文件 @
8ac5672a
...
...
@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Squeeze
,
StridedSlice
,
Tile
,
TensorScatterUpdate
,
Transpose
,
TruncatedNormal
,
TupleToArray
,
UnsortedSegmentMin
,
UnsortedSegmentProd
,
UnsortedSegmentSum
,
SpaceToDepth
,
DepthToSpace
,
SpaceToBatch
,
BatchToSpace
,
SpaceToBatchND
,
BatchToSpaceND
,
BroadcastTo
,
InplaceUpdate
,
ReverseSequence
,
EmbeddingLookup
)
SpaceToBatchND
,
BatchToSpaceND
,
BroadcastTo
,
InplaceUpdate
,
ReverseSequence
,
EmbeddingLookup
,
Unique
)
from
.comm_ops
import
(
AllGather
,
AllReduce
,
_AlltoAll
,
ReduceScatter
,
Broadcast
,
_MirrorOperator
,
ReduceOp
,
_VirtualDataset
,
_VirtualDiv
,
_GetTensorSlice
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
8ac5672a
...
...
@@ -491,6 +491,31 @@ class FusedBatchNormGrad(Primitive):
raise
NotImplementedError
class
UniqueGrad
(
Primitive
):
"""Gradients of Unique operation."""
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'dy'
,
'y'
],
outputs
=
[
'dx'
])
def
__call__
(
self
,
dy
,
x
,
scale
,
save_mean
,
save_inv_variance
):
raise
NotImplementedError
class
UnsortedSegmentSumGrad
(
PrimitiveWithInfer
):
"""Gradients of UnsortedSegmentSum operation."""
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'grads'
,
'ids'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
grads
,
ids
):
return
ids
+
grads
[
len
(
ids
):]
def
infer_dtype
(
self
,
grads
,
ids
):
return
grads
class
BNTrainingReduceGrad
(
PrimitiveWithInfer
):
"""Gradients of FusedBatchNorm operation."""
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
8ac5672a
...
...
@@ -27,7 +27,7 @@ import numpy as np
from
.._utils
import
get_concat_offset
from
..operations.math_ops
import
_infer_shape_reduce
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
,
_run_op
from
..primitive
import
Primitive
,
Primitive
WithInfer
,
prim_attr_register
,
_run_op
from
..._c_expression
import
signature_dtype
as
sig_dtype
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._c_expression
import
signature_rw
as
sig_rw
...
...
@@ -556,6 +556,28 @@ class Transpose(PrimitiveWithInfer):
return
out
class
Unique
(
Primitive
):
"""
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
tensor corresponding to the output unique tensor.
Inputs:
- **x** (Tensor) - The input tensor.
Outputs:
Tuple, containing tensor objects `(y, idx)`, `y` is a tensor has the same type as `x`, `idx` is a tensor
containing indices of elements in the input coressponding to the output tensor.
Examples:
>>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
>>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.int32), Tensor([0, 1, 2, 1], mindspore.float32))
"""
@
prim_attr_register
def
__init__
(
self
):
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'output'
])
class
GatherV2
(
PrimitiveWithInfer
):
"""
Returns a slice of input tensor based on the specified indices and axis.
...
...
mindspore/ops/primitive.py
浏览文件 @
8ac5672a
...
...
@@ -20,6 +20,7 @@ import copy
from
mindspore.common.api
import
_wrap_func
from
mindspore.common
import
Parameter
from
mindspore.common._register_for_tensor
import
tensor_operator_registry
from
mindspore
import
context
from
.._c_expression
import
Primitive_
,
real_run_op
,
prim_type
from
.._c_expression
import
signature_rw
as
sig_rw
from
.._c_expression
import
signature_kind
as
sig_kind
...
...
@@ -138,6 +139,8 @@ class Primitive(Primitive_):
return
self
def
__getattr__
(
self
,
item
):
if
item
==
'infer_dynamic_shape'
:
return
None
if
item
in
super
().
get_attr_dict
():
return
super
().
get_attr_dict
()[
item
]
if
item
in
self
.
attrs
:
...
...
@@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):
def
__infer__
(
self
,
*
args
):
"""Infer shape, type, and value at the same time by using dictionary as arguments."""
is_graph_mode
=
context
.
get_context
(
"mode"
)
==
context
.
GRAPH_MODE
fn_infer_dynamic_shape
=
getattr
(
self
,
'infer_dynamic_shape'
,
None
)
if
is_graph_mode
and
fn_infer_dynamic_shape
is
not
None
:
out
=
fn_infer_dynamic_shape
(
*
args
)
tracks
=
[
'dtype'
,
'value'
]
for
track
in
tracks
:
fn
=
getattr
(
self
,
'infer_'
+
track
)
# fn may return None
out
[
track
]
=
fn
(
*
(
x
[
track
]
for
x
in
args
))
return
out
tracks
=
[
'dtype'
,
'shape'
,
'value'
]
out
=
{}
for
track
in
tracks
:
fn
=
getattr
(
self
,
'infer_'
+
track
)
# fn may return None
out
[
track
]
=
fn
(
*
(
x
[
track
]
for
x
in
args
))
return
out
# in non-graph_mode, it is not necessary to infer min/max shape
if
not
is_graph_mode
:
return
out
def
get_specified_shape
(
elems
,
attr
):
has_specified_shape
=
False
ret_vals
=
[]
for
elem
in
elems
:
if
attr
in
elem
:
has_specified_shape
=
True
ret_vals
.
append
(
elem
[
attr
])
else
:
ret_vals
.
append
(
elem
[
'shape'
])
return
has_specified_shape
,
tuple
(
ret_vals
)
has_min_shape
,
min_shapes
=
get_specified_shape
(
args
,
'min_shape'
)
has_max_shape
,
max_shapes
=
get_specified_shape
(
args
,
'max_shape'
)
if
not
(
has_min_shape
or
has_max_shape
):
return
out
if
has_min_shape
and
has_max_shape
:
fn_infer_shape
=
getattr
(
self
,
'infer_shape'
)
out
[
'min_shape'
]
=
fn_infer_shape
(
*
min_shapes
)
out
[
'max_shape'
]
=
fn_infer_shape
(
*
max_shapes
)
return
out
raise
ValueError
(
'Input args has invalid dynamic shape, args info: {args}'
)
def
prim_attr_register
(
fn
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录