Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cc0add56
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cc0add56
编写于
6月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1930 fix validator for ScatterNdUpdate
Merge pull request !1930 from jiangjinsheng/issue_doc
上级
bd8c623b
dc548afb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
9 deletion
+14
-9
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+14
-9
未找到文件。
mindspore/ops/operations/array_ops.py
浏览文件 @
cc0add56
...
...
@@ -2032,7 +2032,7 @@ class ScatterNd(PrimitiveWithInfer):
Creates an empty tensor, and set values by scattering the update tensor depending on indices.
Inputs:
- **indices** (Tensor) - The index of scattering in the new tensor.
- **indices** (Tensor) - The index of scattering in the new tensor.
With int32 data type.
- **update** (Tensor) - The source Tensor to be scattered.
- **shape** (tuple[int]) - Define the shape of the output tensor. Has the same type as indices.
...
...
@@ -2055,7 +2055,7 @@ class ScatterNd(PrimitiveWithInfer):
def
__infer__
(
self
,
indices
,
update
,
shape
):
shp
=
shape
[
'value'
]
validator
.
check_subclass
(
"update_dtype"
,
update
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
"indices"
:
indices
[
'dtype'
]},
[
mstype
.
int32
]
,
self
.
name
)
validator
.
check_value_type
(
"shape"
,
shp
,
[
tuple
],
self
.
name
)
for
i
,
x
in
enumerate
(
shp
):
validator
.
check_integer
(
"shape[%d]"
%
i
,
x
,
0
,
Rel
.
GT
,
self
.
name
)
...
...
@@ -2159,7 +2159,7 @@ class ScatterUpdate(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor.
With int32 data type.
- **update** (Tensor) - The tensor to update the input tensor, has the same type as input,
and update.shape = indices.shape + input_x.shape[1:].
...
...
@@ -2167,9 +2167,11 @@ class ScatterUpdate(PrimitiveWithInfer):
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> np_update = np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]])
>>> update = Tensor(np_update, mindspore.float32)
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
"""
...
...
@@ -2181,6 +2183,7 @@ class ScatterUpdate(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterUpdate"""
validator
.
check_value_type
(
'use_locking'
,
use_locking
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
...
...
@@ -2189,7 +2192,7 @@ class ScatterUpdate(PrimitiveWithInfer):
return
x_shape
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
[
mstype
.
int32
]
,
self
.
name
)
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
...
...
@@ -2206,14 +2209,15 @@ class ScatterNdUpdate(PrimitiveWithInfer):
Inputs:
- **input_x** (Parameter) - The target tensor, with data type of Parameter.
- **indices** (Tensor) - The index of input tensor.
- **indices** (Tensor) - The index of input tensor
, with int32 data type
.
- **update** (Tensor) - The tensor to add to the input tensor, has the same type as input.
Outputs:
Tensor, has the same shape and type as `input_x`.
Examples:
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> np_x = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])
>>> input_x = mindspore.Parameter(Tensor(np_x, mindspore.float32), name="x")
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.ScatterNdUpdate()
...
...
@@ -2227,6 +2231,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterNdUpdate"""
validator
.
check_value_type
(
'use_locking'
,
use_locking
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
...
...
@@ -2237,7 +2242,7 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return
x_shape
def
infer_dtype
(
self
,
x_dtype
,
indices_dtype
,
value_dtype
):
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
mstype
.
int_type
,
self
.
name
)
validator
.
check_tensor_type_same
({
'indices'
:
indices_dtype
},
[
mstype
.
int32
]
,
self
.
name
)
args
=
{
"x"
:
x_dtype
,
"value"
:
value_dtype
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
x_dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录