Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cc0add56
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc0add56
编写于
6月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1930 fix validator for ScatterNdUpdate
Merge pull request !1930 from jiangjinsheng/issue_doc
上级
bd8c623b
dc548afb
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
14 addition
and
9 deletion
+14
-9
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+14
-9
未找到文件。
mindspore/ops/operations/array_ops.py
浏览文件 @
cc0add56
...
@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
...
@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Inputs:
Inputs:
- **indices** (Tensor) - The index of scattering in the new tensor.
- **indices** (Tensor) - The index of scattering in the new tensor.
With int32 data type.
- **update** (Tensor) - The source Tensor to be scattered.
- **update** (Tensor) - The source Tensor to be scattered.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
...
@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
...
@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
def
__infer__
(
self
,
indices
,
update
,
shape
):
def
__infer__
(
self
,
indices
,
update
,
shape
):
shp
=
shape
[
'value'
]
shp
=
shape
[
'value'
]
validator
.
check_subclass
(
"update_dtype"
,
update
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"update_dtype"
,
update
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
[
mstype
.
int32
]
,
self
.
name
)
validator
.
check_value_type
(
"shape"
,
shp
,
[
tuple
],
self
.
name
)
validator
.
check_value_type
(
"shape"
,
shp
,
[
tuple
],
self
.
name
)
for
i
,
x
in
enumerate
(
shp
):
for
i
,
x
in
enumerate
(
shp
):
validator
.
check_integer
(
"shape[%d]"
%
i
,
x
,
0
,
Rel
.
GT
,
self
.
name
)
validator
.
check_integer
(
"shape[%d]"
%
i
,
x
,
0
,
Rel
.
GT
,
self
.
name
)
...
@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
...
@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs:
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor.
With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
and update.shape = indices.shape + input_x.shape[1:].
...
@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
...
@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
Tensor, has the same shape and type as `input_x`.
Tensor, has the same shape and type as `input_x`.
Examples:
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
>>> output = op(input_x, indices, update)
"""
"""
...
@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
...
@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterUpdate"""
"""Init ScatterUpdate"""
validator
.
check_value_type
(
'use_locking'
,
use_locking
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
...
@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
...
@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
[
mstype
.
int32
]
,
self
.
name
)
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
return
x_dtype
...
@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
...
@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Inputs:
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor
, with int32 data type
.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
Outputs:
Outputs:
Tensor, has the same shape and type as `input_x`.
Tensor, has the same shape and type as `input_x`.
Examples:
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
>>> op = P.ScatterNdUpdate()
...
@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
...
@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterNdUpdate"""
"""Init ScatterNdUpdate"""
validator
.
check_value_type
(
'use_locking'
,
use_locking
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
...
@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
...
@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
[
mstype
.
int32
]
,
self
.
name
)
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
return
x_dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录