Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
75fec82b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
75fec82b
编写于
4月 14, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
resolve pynative operator issue
上级
5ed799d7
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
208 addition
and
106 deletion
+208
-106
mindspore/_extends/builtin_operations.py
mindspore/_extends/builtin_operations.py
+6
-2
mindspore/ccsrc/pipeline/pipeline.cc
mindspore/ccsrc/pipeline/pipeline.cc
+7
-5
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+66
-21
mindspore/common/parameter.py
mindspore/common/parameter.py
+13
-5
mindspore/common/tensor.py
mindspore/common/tensor.py
+29
-14
mindspore/ops/_grad/grad_array_ops.py
mindspore/ops/_grad/grad_array_ops.py
+1
-1
mindspore/ops/_utils/__init__.py
mindspore/ops/_utils/__init__.py
+2
-2
mindspore/ops/_utils/utils.py
mindspore/ops/_utils/utils.py
+28
-1
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+1
-2
mindspore/ops/operations/_grad_ops.py
mindspore/ops/operations/_grad_ops.py
+28
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+1
-52
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+3
-0
tests/ut/python/ir/test_tensor.py
tests/ut/python/ir/test_tensor.py
+22
-0
tests/vm_impl/array_ops_vm_impl.py
tests/vm_impl/array_ops_vm_impl.py
+1
-1
未找到文件。
mindspore/_extends/builtin_operations.py
浏览文件 @
75fec82b
...
...
@@ -125,7 +125,7 @@ def list_len(x):
return
len
(
x
)
# only used in PyNative mode
s
# only used in PyNative mode
def
partial
(
*
args
):
"""Implement `partial`."""
func
=
args
[
0
].
__call__
...
...
@@ -133,10 +133,14 @@ def partial(*args):
return
partial_func
# only used in PyNative mode
s
# only used in PyNative mode
def
depend
(
value
,
expr
):
return
value
# only used in PyNative mode
def
make_ref
(
key
,
value
,
ref
):
return
value
def
scalar_cast
(
x
,
t
):
"""Implement scalar_cast."""
...
...
mindspore/ccsrc/pipeline/pipeline.cc
浏览文件 @
75fec82b
...
...
@@ -616,18 +616,20 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
return
ExecDFGraph
(
info_
,
args
,
phase_s
);
}
#else
if
(
backend
==
"ge"
)
{
std
::
shared_ptr
<
py
::
object
>
ret_val
=
std
::
make_shared
<
py
::
object
>
();
if
(
backend
==
"
ms"
||
backend
==
"
ge"
)
{
auto
ret_val
=
std
::
make_shared
<
py
::
object
>
();
if
(
info_
.
count
(
phase_s
)
!=
0
&&
info_
[
phase_s
]
->
func_graph
!=
nullptr
)
{
if
(
IsGraphOutputValueNodeOrParameter
(
info_
[
phase_s
]
->
func_graph
->
output
(),
args
,
ret_val
))
{
return
*
ret_val
;
}
}
if
(
backend
==
"ge"
)
{
if
(
args
.
size
()
>
0
)
{
return
args
[
0
];
}
return
args
;
}
}
#endif
std
::
size_t
full_arg_size
=
ArgListSize
(
phase_s
);
if
(
size
>
full_arg_size
)
{
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
75fec82b
...
...
@@ -20,11 +20,13 @@
#include <map>
#include <set>
#include <unordered_set>
#include <algorithm>
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "operator/composite/do_signature.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
...
...
@@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
return
converted_ret
;
}
py
::
tuple
ConvertInputs
(
const
PrimitivePyPtr
&
prim
,
const
py
::
tuple
&
py_args
)
{
auto
signature
=
prim
->
signatures
();
std
::
vector
<
SignatureEnumDType
>
dtypes
;
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
if
(
dtypes
.
size
()
==
0
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
return
py_args
;
}
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indexs
;
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
auto
it
=
type_indexs
.
find
(
dtypes
[
i
]);
if
(
it
==
type_indexs
.
end
())
{
(
void
)
type_indexs
.
insert
(
std
::
make_pair
(
dtypes
[
i
],
std
::
vector
<
size_t
>
{
i
}));
}
else
{
it
->
second
.
push_back
(
i
);
}
}
std
::
map
<
SignatureEnumDType
,
size_t
>
dst_type
;
for
(
auto
it
=
type_indexs
.
begin
();
it
!=
type_indexs
.
end
();
(
void
)
++
it
)
{
auto
type
=
it
->
first
;
auto
indexs
=
it
->
second
;
if
(
indexs
.
size
()
<
2
)
{
continue
;
}
size_t
m_index
=
indexs
[
0
];
for
(
size_t
i
=
1
;
i
<
indexs
.
size
();
++
i
)
{
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
indexs
[
i
]]))
{
m_index
=
indexs
[
i
];
}
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
m_index
));
}
py
::
tuple
py_inputs
(
py_args
.
size
());
for
(
size_t
i
=
0
;
i
<
py_args
.
size
();
++
i
)
{
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
!=
dst_type
.
end
()
&&
it
->
second
!=
i
&&
(
py
::
isinstance
<
py
::
int_
>
(
py_args
[
i
])
||
py
::
isinstance
<
py
::
float_
>
(
py_args
[
i
])))
{
auto
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
py_args
[
it
->
second
]);
if
(
py
::
isinstance
<
py
::
int_
>
(
py_args
[
i
]))
{
py_inputs
[
i
]
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
int_
>
(
py_args
[
i
]),
tensor_ptr
->
Dtype
());
}
else
{
py_inputs
[
i
]
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
float_
>
(
py_args
[
i
]),
tensor_ptr
->
Dtype
());
}
continue
;
}
py_inputs
[
i
]
=
py_args
[
i
];
}
return
py_inputs
;
}
void
PynativeInfer
(
const
PrimitivePyPtr
&
prim
,
const
py
::
tuple
&
py_args
,
OpExecInfo
*
const
op_exec_info
)
{
size_t
size
=
py_args
.
size
();
AbstractBasePtrList
args_spec_list
;
...
...
@@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
auto
op_exec_info
=
std
::
make_shared
<
OpExecInfo
>
();
MS_EXCEPTION_IF_NULL
(
op_exec_info
);
op_exec_info
->
op_name
=
py
::
cast
<
std
::
string
>
(
args
[
PY_NAME
]);
if
(
py
::
isinstance
<
py
::
none
>
(
args
[
PY_PRIM
]))
{
py
::
module
ops_mod
=
py
::
module
::
import
(
"mindspore.ops.operations"
);
py
::
object
py_primitive
=
ops_mod
.
attr
(
op_exec_info
->
op_name
.
c_str
())();
op_exec_info
->
py_primitive
=
py
::
cast
<
PrimitivePyPtr
>
(
py_primitive
);
py
::
dict
none_attrs
=
py
::
dict
();
op_exec_info
->
op_attrs
=
none_attrs
;
}
else
{
PrimitivePyPtr
prim
=
py
::
cast
<
PrimitivePyPtr
>
(
args
[
PY_PRIM
]);
auto
prim
=
py
::
cast
<
PrimitivePyPtr
>
(
args
[
PY_PRIM
]);
auto
pyobj
=
prim
->
GetPyObj
();
if
(
pyobj
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"pyobj is empty"
;
}
py
::
tuple
py_args
=
args
[
PY_INPUTS
]
;
py
::
tuple
py_args
=
ConvertInputs
(
prim
,
args
[
PY_INPUTS
])
;
// use python infer method
if
(
ignore_infer_prim
.
find
(
op_exec_info
->
op_name
)
==
ignore_infer_prim
.
end
())
{
PynativeInfer
(
prim
,
py_args
,
op_exec_info
.
get
());
}
op_exec_info
->
py_primitive
=
prim
;
op_exec_info
->
op_attrs
=
py
::
getattr
(
args
[
PY_PRIM
],
"attrs"
);
}
op_exec_info
->
op_inputs
=
args
[
PY_INPUTS
];
op_exec_info
->
op_inputs
=
py_args
;
op_exec_info
->
inputs_mask
=
args
[
PY_INPUT_MASK
];
if
(
op_exec_info
->
op_inputs
.
size
()
!=
op_exec_info
->
inputs_mask
.
size
())
{
MS_LOG
(
ERROR
)
<<
"
"
<<
op_exec_info
->
op_name
<<
" op_
inputs size not equal op_mask"
;
MS_LOG
(
ERROR
)
<<
"
op:"
<<
op_exec_info
->
op_name
<<
"
inputs size not equal op_mask"
;
return
nullptr
;
}
return
op_exec_info
;
...
...
mindspore/common/parameter.py
浏览文件 @
75fec82b
...
...
@@ -14,7 +14,7 @@
# ============================================================================
"""Parameter for cell."""
from
copy
import
copy
from
copy
import
copy
,
deepcopy
import
numpy
as
np
from
.initializer
import
initializer
from
.tensor
import
Tensor
...
...
@@ -156,16 +156,24 @@ class Parameter:
return
self
.
default_input
def
__add__
(
self
,
other
):
return
self
.
default_input
+
other
res
=
deepcopy
(
self
)
res
.
default_input
=
res
.
default_input
+
other
return
res
def
__sub__
(
self
,
other
):
return
self
.
default_input
-
other
res
=
deepcopy
(
self
)
res
.
default_input
=
res
.
default_input
-
other
return
res
def
__mul__
(
self
,
other
):
return
self
.
default_input
*
other
res
=
deepcopy
(
self
)
res
.
default_input
=
res
.
default_input
*
other
return
res
def
__truediv__
(
self
,
other
):
return
self
.
default_input
/
other
res
=
deepcopy
(
self
)
res
.
default_input
=
res
.
default_input
/
other
return
res
def
set_parameter_data
(
self
,
data
):
if
isinstance
(
data
,
(
Tensor
,
list
,
int
,
float
,
...
...
mindspore/common/tensor.py
浏览文件 @
75fec82b
...
...
@@ -70,45 +70,60 @@ class Tensor(Tensor_):
return
str
(
self
.
__str__
())
def
__add__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
raise
TypeError
(
"input_data must be a tensor"
)
check_type
(
'tensor input_data'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
return
out
def
__mul__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
raise
TypeError
(
"input_data must be a tensor"
)
check_type
(
'tensor input_data'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
self
,
other
)
return
out
def
__neg__
(
self
):
return
Tensor
(
-
self
.
asnumpy
())
def
__iadd__
(
self
,
other
):
out
=
self
.
__add__
(
other
)
return
out
def
__radd__
(
self
,
other
):
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
other
,
self
)
return
out
def
__imul__
(
self
,
other
):
out
=
self
.
__mul__
(
other
)
return
out
def
__rmul__
(
self
,
other
):
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__mul__'
)(
other
,
self
)
return
out
def
__truediv__
(
self
,
other
):
if
isinstance
(
other
,
(
int
,
float
)):
other_tensor
=
Tensor
(
other
,
self
.
dtype
()
)
elif
isinstance
(
other
,
Tensor
):
other_tensor
=
other
else
:
raise
TypeError
(
"unsupported type for div operation"
)
out
=
tensor_operator_registry
.
get
(
'__div__'
)(
self
,
other_tensor
)
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__div__'
)(
self
,
other
)
return
out
def
__rtruediv__
(
self
,
other
)
:
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
)
)
out
=
tensor_operator_registry
.
get
(
'__div__'
)(
other
,
self
)
return
out
def
__sub__
(
self
,
other
):
if
not
isinstance
(
other
,
Tensor
):
raise
TypeError
(
"input_data must be a tensor"
)
out
=
self
.
__add__
(
Tensor
(
-
other
.
asnumpy
()))
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
))
out
=
self
.
__add__
(
-
other
)
return
out
def
__isub__
(
self
,
other
):
out
=
self
.
__sub__
(
other
)
return
out
def
__rsub__
(
self
,
other
):
check_type
(
'tensor operation input'
,
other
,
(
Tensor
,
float
,
int
))
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
other
,
Tensor
(
-
self
.
asnumpy
()))
return
out
def
__str__
(
self
):
if
self
.
dtype
()
==
mstype
.
type_none
:
return
"Unknown Tensor type!"
...
...
mindspore/ops/_grad/grad_array_ops.py
浏览文件 @
75fec82b
...
...
@@ -191,7 +191,7 @@ def get_bprop_concat(self):
def
bprop
(
x
,
out
,
dout
):
dx
=
()
out_offset
=
P
.
ConcatOffset
(
F
.
tuple_len
(
x
),
axis
)(
x
)
out_offset
=
G
.
ConcatOffset
(
F
.
tuple_len
(
x
),
axis
)(
x
)
for
i
in
range
(
F
.
tuple_len
(
x
)):
slice_out
=
P
.
Slice
()(
dout
,
out_offset
[
i
],
shape_op
(
x
[
i
]))
dx
=
dx
+
(
slice_out
,)
...
...
mindspore/ops/_utils/__init__.py
浏览文件 @
75fec82b
...
...
@@ -14,6 +14,6 @@
# ============================================================================
"""ops utils."""
from
.
broadcast
import
_get_broadcast_shape
from
.
utils
import
_get_broadcast_shape
,
_get_concat_offset
__all__
=
[
'_get_broadcast_shape'
]
__all__
=
[
'_get_broadcast_shape'
,
'_get_concat_offset'
]
mindspore/ops/_utils/
broadcast
.py
→
mindspore/ops/_utils/
utils
.py
浏览文件 @
75fec82b
...
...
@@ -13,8 +13,11 @@
# limitations under the License.
# ============================================================================
"""
broadcast
"""
"""
utils for operator
"""
from
..._checkparam
import
ParamValidator
as
validator
from
..._checkparam
import
Rel
from
...common
import
dtype
as
mstype
def
_get_broadcast_shape
(
x_shape
,
y_shape
,
prim_name
):
"""
...
...
@@ -57,3 +60,27 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
broadcast_shape_front
=
y_shape
[
0
:
y_len
-
length
]
if
length
==
x_len
else
x_shape
[
0
:
x_len
-
length
]
broadcast_shape
=
broadcast_shape_front
+
broadcast_shape_back
return
broadcast_shape
def
_get_concat_offset
(
x_shp
,
x_type
,
axis
):
"""for concat and concatoffset check args and compute offset"""
validator
.
check_type
(
"shape"
,
x_shp
,
[
tuple
])
validator
.
check_integer
(
"len of input_x shape"
,
len
(
x_shp
),
0
,
Rel
.
GT
)
validator
.
check_subclass
(
"shape0"
,
x_type
[
0
],
mstype
.
tensor
)
validator
.
check_integer
(
"len of input_x0 shape"
,
len
(
x_shp
[
0
]),
0
,
Rel
.
GT
)
rank_base
=
len
(
x_shp
[
0
])
validator
.
check_int_range
(
'axis'
,
axis
,
-
rank_base
-
1
,
rank_base
,
Rel
.
INC_BOTH
)
if
axis
<
0
:
axis
=
axis
+
rank_base
all_shp
=
x_shp
[
0
][
axis
]
offset
=
[
0
,]
for
i
in
range
(
1
,
len
(
x_shp
)):
v
=
x_shp
[
i
]
validator
.
check
(
'len of x_shp[%d]'
%
i
,
len
(
v
),
'len of base'
,
len
(
x_shp
[
0
]))
validator
.
check
(
'x_type[%d]'
%
i
,
x_type
[
i
],
'base'
,
x_type
[
0
])
for
j
in
range
(
rank_base
):
if
j
!=
axis
and
v
[
j
]
!=
x_shp
[
0
][
j
]:
raise
ValueError
(
"Concat evaluator element %d shape in input can not concat with first element"
%
i
)
offset
.
append
(
all_shp
)
all_shp
+=
v
[
axis
]
return
offset
,
all_shp
,
axis
mindspore/ops/operations/__init__.py
浏览文件 @
75fec82b
...
...
@@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build nerual networks or computing functions.
"""
from
.array_ops
import
(
Argmax
,
Argmin
,
Cast
,
Concat
Offset
,
Concat
,
Pack
,
Unpack
,
from
.array_ops
import
(
Argmax
,
Argmin
,
Cast
,
Concat
,
Pack
,
Unpack
,
Diag
,
DiagPart
,
DType
,
ExpandDims
,
Eye
,
Fill
,
GatherNd
,
GatherV2
,
InvertPermutation
,
IsInstance
,
IsSubClass
,
ArgMaxWithValue
,
OnesLike
,
ZerosLike
,
...
...
@@ -200,7 +200,6 @@ __all__ = [
'LogicalOr'
,
'Size'
,
'DepthwiseConv2dNative'
,
'ConcatOffset'
,
'UnsortedSegmentSum'
,
"AllGather"
,
"AllReduce"
,
...
...
mindspore/ops/operations/_grad_ops.py
浏览文件 @
75fec82b
...
...
@@ -20,6 +20,7 @@ from ..._c_expression import signature_kind as sig_kind
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..._checkparam
import
ParamValidator
as
validator
from
..._checkparam
import
Rel
,
check_int_positive
,
check_bool
from
.._utils
import
_get_concat_offset
from
...common
import
dtype
as
mstype
...
...
@@ -107,6 +108,33 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
validator
.
check_two_types_same
(
'x_type'
,
x_type
,
'weight_type'
,
weight_type
)
return
x_type
class
ConcatOffset
(
PrimitiveWithInfer
):
"""primitive for computing Concat's gradient."""
@
prim_attr_register
def
__init__
(
self
,
N
=
2
,
axis
=
0
):
"""init ConcatOffset"""
def
__infer__
(
self
,
input_x
):
axis
=
self
.
axis
x_shp
=
input_x
[
'shape'
]
x_type
=
input_x
[
'dtype'
]
offset
,
_
,
axis
=
_get_concat_offset
(
x_shp
,
x_type
,
axis
)
self
.
add_prim_attr
(
'T'
,
x_type
[
0
].
element_type
())
offset_values
=
[]
for
i
in
range
(
len
(
x_shp
)):
values
=
[]
for
j
in
range
(
len
(
x_shp
[
0
])):
value
=
0
if
j
==
axis
:
value
=
offset
[
i
]
values
.
append
(
value
)
offset_values
.
append
(
tuple
(
values
))
out
=
{
'shape'
:
None
,
'dtype'
:
None
,
'value'
:
tuple
(
offset_values
)}
return
out
class
Conv2DBackpropFilter
(
PrimitiveWithInfer
):
"""
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
75fec82b
...
...
@@ -29,6 +29,7 @@ from ..._checkparam import Rel
from
...common
import
dtype
as
mstype
from
...common.tensor
import
Tensor
from
..operations.math_ops
import
_infer_shape_reduce
from
.._utils
import
_get_concat_offset
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
def
_check_infer_attr_reduce
(
axis
,
keep_dims
):
...
...
@@ -1275,30 +1276,6 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
return
out
def
_get_concat_offset
(
x_shp
,
x_type
,
axis
):
"""for concat and concatoffset check args and compute offset"""
validator
.
check_type
(
"shape"
,
x_shp
,
[
tuple
])
validator
.
check_integer
(
"len of input_x shape"
,
len
(
x_shp
),
0
,
Rel
.
GT
)
validator
.
check_subclass
(
"shape0"
,
x_type
[
0
],
mstype
.
tensor
)
validator
.
check_integer
(
"len of input_x0 shape"
,
len
(
x_shp
[
0
]),
0
,
Rel
.
GT
)
rank_base
=
len
(
x_shp
[
0
])
validator
.
check_int_range
(
'axis'
,
axis
,
-
rank_base
-
1
,
rank_base
,
Rel
.
INC_BOTH
)
if
axis
<
0
:
axis
=
axis
+
rank_base
all_shp
=
x_shp
[
0
][
axis
]
offset
=
[
0
,]
for
i
in
range
(
1
,
len
(
x_shp
)):
v
=
x_shp
[
i
]
validator
.
check
(
'len of x_shp[%d]'
%
i
,
len
(
v
),
'len of base'
,
len
(
x_shp
[
0
]))
validator
.
check
(
'x_type[%d]'
%
i
,
x_type
[
i
],
'base'
,
x_type
[
0
])
for
j
in
range
(
rank_base
):
if
j
!=
axis
and
v
[
j
]
!=
x_shp
[
0
][
j
]:
raise
ValueError
(
"Concat evaluator element %d shape in input can not concat with first element"
%
i
)
offset
.
append
(
all_shp
)
all_shp
+=
v
[
axis
]
return
offset
,
all_shp
,
axis
class
Concat
(
PrimitiveWithInfer
):
r
"""
Concat tensor in specified axis.
...
...
@@ -1531,34 +1508,6 @@ class Slice(PrimitiveWithInfer):
'value'
:
None
}
class
ConcatOffset
(
PrimitiveWithInfer
):
"""primitive for computing Concat's gradient."""
@
prim_attr_register
def
__init__
(
self
,
N
=
2
,
axis
=
0
):
"""init ConcatOffset"""
def
__infer__
(
self
,
input_x
):
axis
=
self
.
axis
x_shp
=
input_x
[
'shape'
]
x_type
=
input_x
[
'dtype'
]
offset
,
_
,
axis
=
_get_concat_offset
(
x_shp
,
x_type
,
axis
)
self
.
add_prim_attr
(
'T'
,
x_type
[
0
].
element_type
())
offset_values
=
[]
for
i
in
range
(
len
(
x_shp
)):
values
=
[]
for
j
in
range
(
len
(
x_shp
[
0
])):
value
=
0
if
j
==
axis
:
value
=
offset
[
i
]
values
.
append
(
value
)
offset_values
.
append
(
tuple
(
values
))
out
=
{
'shape'
:
None
,
'dtype'
:
None
,
'value'
:
tuple
(
offset_values
)}
return
out
class
Select
(
PrimitiveWithInfer
):
r
"""
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
75fec82b
...
...
@@ -271,3 +271,6 @@ class MakeRefKey(Primitive):
@
prim_attr_register
def
__init__
(
self
,
tag
):
validator
.
check_type
(
'tag'
,
tag
,
(
str
,))
def
__call__
(
self
):
pass
tests/ut/python/ir/test_tensor.py
浏览文件 @
75fec82b
...
...
@@ -24,6 +24,7 @@ import pytest
import
mindspore
as
ms
import
mindspore.common.api
as
me
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.common.parameter
import
Parameter
from
mindspore.common.initializer
import
initializer
from
..ut_filter
import
non_graph_engine
...
...
@@ -396,3 +397,24 @@ def test_tensor_dtype_fp32_to_bool():
input
=
ms
.
Tensor
(
input
)
input_me
=
ms
.
Tensor
(
input
,
dtype
=
ms
.
bool_
)
def
test_tensor_operation
():
x
=
Tensor
(
np
.
ones
((
3
,
3
))
*
4
)
res
=
x
+
1
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
5
)
res
=
1
+
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
5
)
res
=
x
-
2
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
6
-
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
x
*
3
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
12
)
res
=
3
*
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
12
)
res
=
x
/
2
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
res
=
8
/
x
assert
np
.
all
(
res
.
asnumpy
()
==
np
.
ones
((
3
,
3
))
*
2
)
with
pytest
.
raises
(
TypeError
):
res
=
x
*
(
2
,
3
)
tests/vm_impl/array_ops_vm_impl.py
浏览文件 @
75fec82b
...
...
@@ -190,7 +190,7 @@ def vm_impl_slice(self):
return
vm_impl
@
vm_impl_getters
.
register
(
P
.
ConcatOffset
)
@
vm_impl_getters
.
register
(
P
.
_grad_ops
.
ConcatOffset
)
def
vm_impl_concatOffset
(
self
):
"""Generate vm_impl function for ConcatOffset"""
def
vm_impl
(
x
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录