Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
74582aaa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
74582aaa
编写于
12月 16, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
12月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
0d tensor for scatter_ and scatter_nd (#49072)
上级
69536892
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
148 addition
and
47 deletion
+148
-47
paddle/phi/infermeta/ternary.cc
paddle/phi/infermeta/ternary.cc
+61
-46
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+69
-0
python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
...dle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
+17
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+1
-1
未找到文件。
paddle/phi/infermeta/ternary.cc
浏览文件 @
74582aaa
...
@@ -1039,6 +1039,20 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
...
@@ -1039,6 +1039,20 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
const
auto
&
updates_dims
=
updates
.
dims
();
const
auto
&
updates_dims
=
updates
.
dims
();
auto
updates_dims_size
=
updates_dims
.
size
();
auto
updates_dims_size
=
updates_dims
.
size
();
if
(
updates_dims_size
==
0
)
{
// check for 0d updates
PADDLE_ENFORCE_EQ
(
index_dims_size
,
1
,
phi
::
errors
::
InvalidArgument
(
"When the updates is a 0d tensor, the "
"index should be a 1d tensor."
));
PADDLE_ENFORCE_EQ
(
index_dims
[
index_dims_size
-
1
],
ref_dims_size
,
phi
::
errors
::
InvalidArgument
(
"When the update is a 0d tensor, The last dimension of "
"Input(Index)'s shape should be equal with the rank of Input(X)."
));
}
else
{
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_LE
(
index_dims
[
index_dims_size
-
1
],
index_dims
[
index_dims_size
-
1
],
ref_dims_size
,
ref_dims_size
,
...
@@ -1063,7 +1077,7 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
...
@@ -1063,7 +1077,7 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
for
(
int64_t
i
=
index_dims
[
index_dims_size
-
1
];
i
<
ref_dims_size
;
++
i
)
{
for
(
int64_t
i
=
index_dims
[
index_dims_size
-
1
];
i
<
ref_dims_size
;
++
i
)
{
r_updates_dims
.
emplace_back
(
ref_dims
[
i
]);
r_updates_dims
.
emplace_back
(
ref_dims
[
i
]);
}
}
// check for non-0d updates
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
r_updates_dims
.
size
(),
r_updates_dims
.
size
(),
updates_dims_size
,
updates_dims_size
,
...
@@ -1088,6 +1102,7 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
...
@@ -1088,6 +1102,7 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
i
,
i
,
updates_dims
[
i
]));
updates_dims
[
i
]));
}
}
}
out
->
set_dims
(
ref_dims
);
out
->
set_dims
(
ref_dims
);
out
->
share_lod
(
x
);
out
->
share_lod
(
x
);
out
->
set_dtype
(
x
.
dtype
());
out
->
set_dtype
(
x
.
dtype
());
...
...
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
浏览文件 @
74582aaa
...
@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase):
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
def
test_scatter__1D
(
self
):
x
=
paddle
.
to_tensor
([
1.0
,
3.0
,
5.0
,
7.0
,
9.0
])
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4.0
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
self
.
assertEqual
(
out
.
numpy
()[
2
],
4
)
def
test_scatter__XD
(
self
):
x
=
paddle
.
to_tensor
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]])
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
to_tensor
([
1.0
,
2.0
,
3.0
])
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
for
i
in
range
(
3
):
self
.
assertEqual
(
out
.
numpy
()[
1
][
i
],
updates
.
numpy
()[
i
])
def
test_scatter_nd
(
self
):
index
=
paddle
.
to_tensor
([
3
],
dtype
=
"int64"
,
stop_gradient
=
False
)
updates
=
paddle
.
full
([],
2
,
dtype
=
'float32'
)
updates
.
stop_gradient
=
False
shape
=
[
5
]
out
=
paddle
.
scatter_nd
(
index
,
updates
,
shape
)
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
5
])
self
.
assertEqual
(
out
.
numpy
()[
3
],
2
)
self
.
assertEqual
(
out
.
grad
.
shape
,
[
5
])
class
TestSundryAPIStatic
(
unittest
.
TestCase
):
class
TestSundryAPIStatic
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase):
...
@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase):
self
.
assertEqual
(
res2
.
shape
,
(
2
,
2
))
self
.
assertEqual
(
res2
.
shape
,
(
2
,
2
))
self
.
assertEqual
(
res3
.
shape
,
(
1
,
1
))
self
.
assertEqual
(
res3
.
shape
,
(
1
,
1
))
@
prog_scope
()
def
test_scatter__1D
(
self
):
x
=
paddle
.
full
([
10
],
1.0
,
'float32'
)
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4
,
'float32'
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
out
])
self
.
assertEqual
(
res
[
0
][
2
],
4
)
@
prog_scope
()
def
test_scatter__XD
(
self
):
x
=
paddle
.
full
([
2
,
3
],
1.0
,
'float32'
)
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
full
([
3
],
4
,
'float32'
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
out
])
for
i
in
range
(
3
):
self
.
assertEqual
(
res
[
0
][
1
][
i
],
4
)
@
prog_scope
()
def
test_scatter_nd
(
self
):
index
=
paddle
.
static
.
data
(
name
=
'index'
,
shape
=
[
1
],
dtype
=
'int64'
)
updates
=
paddle
.
full
([],
2
,
'float32'
)
shape
=
[
5
]
index_data
=
np
.
array
([
3
],
dtype
=
np
.
longlong
)
out
=
paddle
.
scatter_nd
(
index
,
updates
,
shape
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
feed
=
{
'index'
:
index_data
},
fetch_list
=
[
out
])
self
.
assertEqual
(
res
[
0
].
shape
,
(
5
,))
self
.
assertEqual
(
res
[
0
][
3
],
2
)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
浏览文件 @
74582aaa
...
@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase):
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
def
test_scatter__1D
(
self
):
x
=
paddle
.
to_tensor
([
1.0
,
3.0
,
5.0
,
7.0
,
9.0
])
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4.0
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
self
.
assertEqual
(
out
.
numpy
()[
2
],
4
)
def
test_scatter__XD
(
self
):
x
=
paddle
.
to_tensor
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]])
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
to_tensor
([
1.0
,
2.0
,
3.0
])
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
for
i
in
range
(
3
):
self
.
assertEqual
(
out
.
numpy
()[
1
][
i
],
updates
.
numpy
()[
i
])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
74582aaa
...
@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None):
...
@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None):
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
Args:
Args:
index (Tensor): The index input with ndim > 1 and index.shape[-1] <= len(shape).
index (Tensor): The index input with ndim >
=
1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indexes.
Its dtype should be int32 or int64 as it is used as indexes.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录