Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
74582aaa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
74582aaa
编写于
12月 16, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
12月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
0d tensor for scatter_ and scatter_nd (#49072)
上级
69536892
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
148 addition
and
47 deletion
+148
-47
paddle/phi/infermeta/ternary.cc
paddle/phi/infermeta/ternary.cc
+61
-46
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+69
-0
python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
...dle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
+17
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+1
-1
未找到文件。
paddle/phi/infermeta/ternary.cc
浏览文件 @
74582aaa
...
@@ -1039,54 +1039,69 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
...
@@ -1039,54 +1039,69 @@ void ScatterNdAddInferMeta(const MetaTensor& x,
const
auto
&
updates_dims
=
updates
.
dims
();
const
auto
&
updates_dims
=
updates
.
dims
();
auto
updates_dims_size
=
updates_dims
.
size
();
auto
updates_dims_size
=
updates_dims
.
size
();
PADDLE_ENFORCE_LE
(
if
(
updates_dims_size
==
0
)
{
index_dims
[
index_dims_size
-
1
],
// check for 0d updates
ref_dims_size
,
PADDLE_ENFORCE_EQ
(
phi
::
errors
::
InvalidArgument
(
index_dims_size
,
"The last dimension of Input(Index)'s shape should be no greater "
1
,
"than the rank of Input(X), but received the last dimension of "
phi
::
errors
::
InvalidArgument
(
"When the updates is a 0d tensor, the "
"Input(Index)'s shape is %d, the rank of Input(X) is %d."
,
"index should be a 1d tensor."
));
index_dims
[
index_dims_size
-
1
],
ref_dims_size
));
PADDLE_ENFORCE_GE
(
index_dims_size
,
2UL
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input(Index) should be greater than 1, "
"but received the rank of Input(Index) is %d."
,
index_dims_size
));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std
::
vector
<
int64_t
>
r_updates_dims
;
for
(
int64_t
i
=
0
;
i
<
index_dims_size
-
1
;
++
i
)
{
r_updates_dims
.
emplace_back
(
index_dims
[
i
]);
}
for
(
int64_t
i
=
index_dims
[
index_dims_size
-
1
];
i
<
ref_dims_size
;
++
i
)
{
r_updates_dims
.
emplace_back
(
ref_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
r_updates_dims
.
size
(),
updates_dims_size
,
phi
::
errors
::
InvalidArgument
(
"Updates has wrong shape. The shape of Updates and Input(Updates) "
"should be same, but received the shape of Updates is %d, "
"the shape of Input(Updates) is %d."
,
r_updates_dims
.
size
(),
updates_dims_size
));
for
(
int64_t
i
=
0
;
i
<
updates_dims_size
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
r_updates_dims
[
i
],
index_dims
[
index_dims_size
-
1
],
updates_dims
[
i
]
,
ref_dims_size
,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"Updates has wrong shape. The dimensions of Updates and "
"When the update is a 0d tensor, The last dimension of "
"Input(Updates) should match, but received Updates's"
"Input(Index)'s shape should be equal with the rank of Input(X)."
));
"%d-th dimension is %d, Input(Updates)'s %d-th "
}
else
{
"dimension is %d."
,
PADDLE_ENFORCE_LE
(
i
,
index_dims
[
index_dims_size
-
1
],
r_updates_dims
[
i
],
ref_dims_size
,
i
,
phi
::
errors
::
InvalidArgument
(
updates_dims
[
i
]));
"The last dimension of Input(Index)'s shape should be no greater "
"than the rank of Input(X), but received the last dimension of "
"Input(Index)'s shape is %d, the rank of Input(X) is %d."
,
index_dims
[
index_dims_size
-
1
],
ref_dims_size
));
PADDLE_ENFORCE_GE
(
index_dims_size
,
2UL
,
phi
::
errors
::
InvalidArgument
(
"The rank of Input(Index) should be greater than 1, "
"but received the rank of Input(Index) is %d."
,
index_dims_size
));
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
std
::
vector
<
int64_t
>
r_updates_dims
;
for
(
int64_t
i
=
0
;
i
<
index_dims_size
-
1
;
++
i
)
{
r_updates_dims
.
emplace_back
(
index_dims
[
i
]);
}
for
(
int64_t
i
=
index_dims
[
index_dims_size
-
1
];
i
<
ref_dims_size
;
++
i
)
{
r_updates_dims
.
emplace_back
(
ref_dims
[
i
]);
}
// check for non-0d updates
PADDLE_ENFORCE_EQ
(
r_updates_dims
.
size
(),
updates_dims_size
,
phi
::
errors
::
InvalidArgument
(
"Updates has wrong shape. The shape of Updates and Input(Updates) "
"should be same, but received the shape of Updates is %d, "
"the shape of Input(Updates) is %d."
,
r_updates_dims
.
size
(),
updates_dims_size
));
for
(
int64_t
i
=
0
;
i
<
updates_dims_size
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
r_updates_dims
[
i
],
updates_dims
[
i
],
phi
::
errors
::
InvalidArgument
(
"Updates has wrong shape. The dimensions of Updates and "
"Input(Updates) should match, but received Updates's"
"%d-th dimension is %d, Input(Updates)'s %d-th "
"dimension is %d."
,
i
,
r_updates_dims
[
i
],
i
,
updates_dims
[
i
]));
}
}
}
out
->
set_dims
(
ref_dims
);
out
->
set_dims
(
ref_dims
);
out
->
share_lod
(
x
);
out
->
share_lod
(
x
);
...
...
python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
浏览文件 @
74582aaa
...
@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -682,6 +682,36 @@ class TestSundryAPI(unittest.TestCase):
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
def
test_scatter__1D
(
self
):
x
=
paddle
.
to_tensor
([
1.0
,
3.0
,
5.0
,
7.0
,
9.0
])
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4.0
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
self
.
assertEqual
(
out
.
numpy
()[
2
],
4
)
def
test_scatter__XD
(
self
):
x
=
paddle
.
to_tensor
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]])
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
to_tensor
([
1.0
,
2.0
,
3.0
])
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
for
i
in
range
(
3
):
self
.
assertEqual
(
out
.
numpy
()[
1
][
i
],
updates
.
numpy
()[
i
])
def
test_scatter_nd
(
self
):
index
=
paddle
.
to_tensor
([
3
],
dtype
=
"int64"
,
stop_gradient
=
False
)
updates
=
paddle
.
full
([],
2
,
dtype
=
'float32'
)
updates
.
stop_gradient
=
False
shape
=
[
5
]
out
=
paddle
.
scatter_nd
(
index
,
updates
,
shape
)
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
5
])
self
.
assertEqual
(
out
.
numpy
()[
3
],
2
)
self
.
assertEqual
(
out
.
grad
.
shape
,
[
5
])
class
TestSundryAPIStatic
(
unittest
.
TestCase
):
class
TestSundryAPIStatic
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase):
...
@@ -845,6 +875,45 @@ class TestSundryAPIStatic(unittest.TestCase):
self
.
assertEqual
(
res2
.
shape
,
(
2
,
2
))
self
.
assertEqual
(
res2
.
shape
,
(
2
,
2
))
self
.
assertEqual
(
res3
.
shape
,
(
1
,
1
))
self
.
assertEqual
(
res3
.
shape
,
(
1
,
1
))
@
prog_scope
()
def
test_scatter__1D
(
self
):
x
=
paddle
.
full
([
10
],
1.0
,
'float32'
)
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4
,
'float32'
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
out
])
self
.
assertEqual
(
res
[
0
][
2
],
4
)
@
prog_scope
()
def
test_scatter__XD
(
self
):
x
=
paddle
.
full
([
2
,
3
],
1.0
,
'float32'
)
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
full
([
3
],
4
,
'float32'
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
out
])
for
i
in
range
(
3
):
self
.
assertEqual
(
res
[
0
][
1
][
i
],
4
)
@
prog_scope
()
def
test_scatter_nd
(
self
):
index
=
paddle
.
static
.
data
(
name
=
'index'
,
shape
=
[
1
],
dtype
=
'int64'
)
updates
=
paddle
.
full
([],
2
,
'float32'
)
shape
=
[
5
]
index_data
=
np
.
array
([
3
],
dtype
=
np
.
longlong
)
out
=
paddle
.
scatter_nd
(
index
,
updates
,
shape
)
paddle
.
static
.
append_backward
(
out
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
feed
=
{
'index'
:
index_data
},
fetch_list
=
[
out
])
self
.
assertEqual
(
res
[
0
].
shape
,
(
5
,))
self
.
assertEqual
(
res
[
0
][
3
],
2
)
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
浏览文件 @
74582aaa
...
@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -504,6 +504,23 @@ class TestSundryAPI(unittest.TestCase):
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x2
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
self
.
assertEqual
(
x3
.
grad
.
shape
,
[])
def
test_scatter__1D
(
self
):
x
=
paddle
.
to_tensor
([
1.0
,
3.0
,
5.0
,
7.0
,
9.0
])
index
=
paddle
.
full
([],
2
,
'int64'
)
updates
=
paddle
.
full
([],
4.0
)
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
self
.
assertEqual
(
out
.
numpy
()[
2
],
4
)
def
test_scatter__XD
(
self
):
x
=
paddle
.
to_tensor
([[
1.0
,
2.0
,
3.0
],
[
4.0
,
5.0
,
6.0
]])
index
=
paddle
.
full
([],
1
,
'int64'
)
updates
=
paddle
.
to_tensor
([
1.0
,
2.0
,
3.0
])
out
=
paddle
.
scatter_
(
x
,
index
,
updates
)
for
i
in
range
(
3
):
self
.
assertEqual
(
out
.
numpy
()[
1
][
i
],
updates
.
numpy
()[
i
])
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
class
TestNoBackwardAPI
(
unittest
.
TestCase
):
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
74582aaa
...
@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None):
...
@@ -3078,7 +3078,7 @@ def scatter_nd(index, updates, shape, name=None):
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
seen :code:`scatter_nd_add` . This op is the inverse of the :code:`gather_nd` op.
Args:
Args:
index (Tensor): The index input with ndim > 1 and index.shape[-1] <= len(shape).
index (Tensor): The index input with ndim >
=
1 and index.shape[-1] <= len(shape).
Its dtype should be int32 or int64 as it is used as indexes.
Its dtype should be int32 or int64 as it is used as indexes.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
updates (Tensor): The updated value of scatter_nd op. Its dtype should be float32, float64.
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
It must have the shape index.shape[:-1] + shape[index.shape[-1]:]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录