Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
66682be0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
66682be0
编写于
1月 31, 2023
作者:
R
RedContritio
提交者:
GitHub
1月 31, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix 堆栈溢出 (stack overflow) of case9: paddle.repeat_interleave (#49982)
* support negative index in repeat_interleave * add unittest
上级
baf96a12
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
46 addition
and
13 deletion
+46
-13
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+26
-13
python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py
...paddle/fluid/tests/unittests/test_repeat_interleave_op.py
+20
-0
未找到文件。
paddle/phi/infermeta/unary.cc
浏览文件 @
66682be0
...
...
@@ -3075,27 +3075,40 @@ void RepeatInterleaveInferMeta(const MetaTensor& x,
MetaTensor
*
out
)
{
const
auto
&
input_dim
=
x
.
dims
();
auto
output_dim
=
phi
::
vectorize
(
input_dim
);
auto
n_dim
=
dim
;
PADDLE_ENFORCE_EQ
(
dim
<
input_dim
.
size
()
&&
dim
>=
(
0
-
input_dim
.
size
()),
true
,
if
(
n_dim
<
0
)
n_dim
+=
input_dim
.
size
();
PADDLE_ENFORCE_LT
(
dim
,
input_dim
.
size
(),
phi
::
errors
::
OutOfRange
(
"Attr(dim) is out of range, It's expected "
"to be in range of [
-
%d, %d]. But received Attr(dim) = %d."
,
input_dim
.
size
(),
"to be in range of [%d, %d]. But received Attr(dim) = %d."
,
-
input_dim
.
size
(),
input_dim
.
size
()
-
1
,
dim
));
PADDLE_ENFORCE_EQ
(
repeats
>
0
,
true
,
PADDLE_ENFORCE_GE
(
dim
,
(
0
-
input_dim
.
size
()),
phi
::
errors
::
OutOfRange
(
"Attr(dim) is out of range, It's expected "
"to be in range of [%d, %d]. But received Attr(dim) = %d."
,
-
input_dim
.
size
(),
input_dim
.
size
()
-
1
,
dim
));
PADDLE_ENFORCE_GT
(
repeats
,
0
,
phi
::
errors
::
InvalidArgument
(
"repeats should be larger than zero"
));
PADDLE_ENFORCE_N
E
(
out
,
nullptr
,
phi
::
errors
::
InvalidArgument
(
"repeat_interleave's output tensor can't be nullptr"
));
PADDLE_ENFORCE_N
OT_NULL
(
out
,
phi
::
errors
::
InvalidArgument
(
"repeat_interleave's output tensor can't be nullptr"
));
output_dim
[
dim
]
=
input_dim
[
dim
]
*
repeats
;
output_dim
[
n_dim
]
=
input_dim
[
n_
dim
]
*
repeats
;
out
->
set_dims
(
phi
::
make_ddim
(
output_dim
));
out
->
share_lod
(
x
);
out
->
set_dtype
(
x
.
dtype
());
...
...
python/paddle/fluid/tests/unittests/test_repeat_interleave_op.py
浏览文件 @
66682be0
...
...
@@ -188,6 +188,26 @@ class TestIndexSelectAPI(unittest.TestCase):
expect_out
=
np
.
repeat
(
self
.
data_zero_dim_x
,
repeats
)
np
.
testing
.
assert_allclose
(
expect_out
,
np
.
array
(
res
),
rtol
=
1e-05
)
# case 4 negative axis:
with
program_guard
(
Program
(),
Program
()):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
-
1
,
4
],
dtype
=
'float32'
)
x
.
desc
.
set_need_check_feed
(
False
)
index
=
paddle
.
static
.
data
(
name
=
'repeats_'
,
shape
=
[
4
],
dtype
=
'int32'
,
)
index
.
desc
.
set_need_check_feed
(
False
)
z
=
paddle
.
repeat_interleave
(
x
,
index
,
axis
=-
1
)
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
(
res
,)
=
exe
.
run
(
feed
=
{
'x'
:
self
.
data_x
,
'repeats_'
:
self
.
data_index
},
fetch_list
=
[
z
.
name
],
return_numpy
=
False
,
)
expect_out
=
np
.
repeat
(
self
.
data_x
,
self
.
data_index
,
axis
=-
1
)
np
.
testing
.
assert_allclose
(
expect_out
,
np
.
array
(
res
),
rtol
=
1e-05
)
def
test_dygraph_api
(
self
):
self
.
input_data
()
# case axis none
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录