Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3d4d995f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d4d995f
编写于
6月 14, 2023
作者:
zhouweiwei2014
提交者:
GitHub
6月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Zero-Dim] paddle.nanmedian/nanquantile support 0D Tensor (#54500)
* [Zero-Dim] paddle.nanmedian support 0D Tensor * fix CI
上级
ca59c72b
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
359 addition
and
376 deletion
+359
-376
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+30
-20
paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
+35
-38
paddle/phi/kernels/cpu/nanmedian_kernel.cc
paddle/phi/kernels/cpu/nanmedian_kernel.cc
+28
-41
paddle/phi/kernels/funcs/nanmedian_utils.h
paddle/phi/kernels/funcs/nanmedian_utils.h
+44
-1
paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
+42
-40
paddle/phi/kernels/gpu/nanmedian_kernel.cu
paddle/phi/kernels/gpu/nanmedian_kernel.cu
+31
-38
paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h
paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h
+0
-65
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+7
-26
test/legacy_test/test_nanmedian.py
test/legacy_test/test_nanmedian.py
+34
-1
test/legacy_test/test_zero_dim_tensor.py
test/legacy_test/test_zero_dim_tensor.py
+108
-106
未找到文件。
paddle/phi/infermeta/unary.cc
浏览文件 @
3d4d995f
...
@@ -2323,37 +2323,47 @@ void NanmedianInferMeta(const MetaTensor& x,
...
@@ -2323,37 +2323,47 @@ void NanmedianInferMeta(const MetaTensor& x,
for
(
int64_t
i
=
0
;
i
<
x_rank
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
x_rank
;
i
++
)
{
out_dim
.
push_back
(
1
);
out_dim
.
push_back
(
1
);
}
}
}
else
{
out_dim
.
push_back
(
1
);
}
}
}
else
{
}
else
{
std
::
vector
<
int64_t
>
clean
ed_axis
;
std
::
vector
<
int64_t
>
format
ed_axis
;
for
(
auto
&
axis
:
axis_list
)
{
for
(
auto
&
axis
:
axis_list
)
{
if
(
x_rank
==
0
)
{
PADDLE_ENFORCE_EQ
(
axis
==
0
||
axis
==
-
1
,
true
,
phi
::
errors
::
InvalidArgument
(
"When input 0D Tensor, each element of the axis "
"can only be -1, 0, None"
));
}
else
{
PADDLE_ENFORCE_LT
(
axis
,
x_rank
,
errors
::
InvalidArgument
(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d."
,
x_rank
,
axis
));
PADDLE_ENFORCE_GE
(
axis
,
-
x_rank
,
errors
::
InvalidArgument
(
"each element of the axis should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received axis = %d."
,
x_rank
,
axis
));
}
if
(
axis
<
0
)
axis
+=
x_rank
;
if
(
axis
<
0
)
axis
+=
x_rank
;
PADDLE_ENFORCE_LT
(
axis
,
x_rank
,
errors
::
InvalidArgument
(
"Attr(axis) value should be in range [-R, R-1], R is "
"the rank of Input(X). But received axis: %d, R: %d. "
"Current Input(X)'s shape is=[%s]."
,
axis
,
x_rank
,
x_dim
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
std
::
find
(
cleaned_axis
.
begin
(),
clean
ed_axis
.
end
(),
axis
),
std
::
find
(
formated_axis
.
begin
(),
format
ed_axis
.
end
(),
axis
),
clean
ed_axis
.
end
(),
format
ed_axis
.
end
(),
errors
::
InvalidArgument
(
"Attr(axes) has duplicated elements: %d."
,
errors
::
InvalidArgument
(
"Attr(axes) has duplicated elements: %d."
,
static_cast
<
int
>
(
axis
)));
static_cast
<
int
>
(
axis
)));
clean
ed_axis
.
push_back
(
axis
);
format
ed_axis
.
push_back
(
axis
);
}
}
for
(
int64_t
i
=
0
;
i
<
x_rank
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
x_rank
;
i
++
)
{
if
(
std
::
find
(
cleaned_axis
.
begin
(),
clean
ed_axis
.
end
(),
i
)
==
if
(
std
::
find
(
formated_axis
.
begin
(),
format
ed_axis
.
end
(),
i
)
==
clean
ed_axis
.
end
())
{
format
ed_axis
.
end
())
{
out_dim
.
push_back
(
x_dim
[
i
]);
out_dim
.
push_back
(
x_dim
[
i
]);
}
else
if
(
keep_dim
)
{
}
else
if
(
keep_dim
)
{
out_dim
.
push_back
(
1
);
out_dim
.
push_back
(
1
);
...
...
paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc
浏览文件 @
3d4d995f
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/
impl/nanmedian_grad_kernel_impl
.h"
#include "paddle/phi/kernels/
funcs/nanmedian_utils
.h"
namespace
phi
{
namespace
phi
{
...
@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
...
@@ -26,67 +26,64 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
const
IntArray
&
axes
UNUSED
,
DenseTensor
*
x_grad
)
{
DenseTensor
*
x_grad
,
T
*
dx_data
=
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
T
*
x_grad_ptr
)
{
if
(
!
dx_data
)
return
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
if
(
!
x_grad_ptr
)
return
;
const
int64_t
*
m_
ptr
=
median_index
.
data
<
int64_t
>
();
const
int64_t
*
m_
data
=
median_index
.
data
<
int64_t
>
();
const
T
*
out_grad_ptr
=
out_grad
.
data
<
T
>
();
const
T
*
dout_data
=
out_grad
.
data
<
T
>
();
int64_t
numel
=
x
.
numel
();
int64_t
numel
=
x
.
numel
();
auto
x_dim
=
x
.
dims
();
auto
x_dim
=
x
.
dims
();
int64_t
rank
=
x_dim
.
size
();
int64_t
rank
=
x_dim
.
size
();
int64_t
stride
=
x_dim
[
rank
-
1
];
int64_t
stride
=
x_dim
[
rank
-
1
];
int64_t
pre_dim
=
numel
/
stride
;
int64_t
pre_dim
=
numel
/
stride
;
int64_t
i
=
0
;
int64_t
i
=
0
;
int64_t
offset
=
0
;
int64_t
offset
=
0
;
T
div_factor
=
static_cast
<
T
>
(
2.0
);
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
if
(
m_
ptr
[
2
*
i
]
>=
0
)
{
if
(
m_
data
[
2
*
i
]
>=
0
)
{
if
(
m_
ptr
[
2
*
i
]
==
m_ptr
[
2
*
i
+
1
])
{
if
(
m_
data
[
2
*
i
]
==
m_data
[
2
*
i
+
1
])
{
x_grad_ptr
[
offset
+
m_ptr
[
2
*
i
]]
=
out_grad_ptr
[
i
];
dx_data
[
offset
+
m_data
[
2
*
i
]]
=
dout_data
[
i
];
}
else
{
}
else
{
x_grad_ptr
[
offset
+
m_ptr
[
2
*
i
]]
=
out_grad_ptr
[
i
]
/
div_factor
;
dx_data
[
offset
+
m_data
[
2
*
i
]]
=
dout_data
[
i
]
/
static_cast
<
T
>
(
2.0
);
x_grad_ptr
[
offset
+
m_ptr
[
2
*
i
+
1
]]
=
out_grad_ptr
[
i
]
/
div_factor
;
dx_data
[
offset
+
m_data
[
2
*
i
+
1
]]
=
dout_data
[
i
]
/
static_cast
<
T
>
(
2.0
);
}
}
}
}
offset
+=
stride
;
offset
+=
stride
;
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
BaseMedianGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
IntArray
&
axes
,
DenseTensor
*
x_grad
)
{
auto
rank
=
x
.
dims
().
size
();
T
*
x_grad_ptr
=
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
if
(
axes
.
size
()
&&
(
rank
>
1
))
{
DenseTensor
tmp_x_grad
(
*
x_grad
);
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
median_index
,
out_grad
,
axes
,
&
tmp_x_grad
,
x_grad_ptr
);
PostprocessMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
&
tmp_x_grad
,
axes
,
x_grad
);
}
else
{
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
median_index
,
out_grad
,
axes
,
x_grad
,
x_grad_ptr
);
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
NanmedianGradKernel
(
const
Context
&
dev_ctx
,
void
NanmedianGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
const
IntArray
&
axes
,
const
IntArray
&
axes
,
bool
keep
_
dim
UNUSED
,
bool
keepdim
UNUSED
,
DenseTensor
*
x_grad
)
{
DenseTensor
*
x_grad
)
{
BaseMedianGradKernel
<
T
,
Context
>
(
DenseTensor
tmp_x
;
dev_ctx
,
input
,
median_index
,
out_grad
,
axes
,
x_grad
);
auto
rank
=
x
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
tmp_x
=
x
;
tmp_x
.
Resize
({
x
.
numel
()});
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
median_index
,
out_grad
,
x_grad
);
}
else
{
funcs
::
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
&
tmp_x
);
DenseTensor
tmp_x_grad
;
tmp_x_grad
.
Resize
(
x_grad
->
dims
());
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
median_index
,
out_grad
,
&
tmp_x_grad
);
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
funcs
::
PostprocessMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
&
tmp_x_grad
,
axes
,
x_grad
);
}
}
}
}
// namespace phi
}
// namespace phi
...
...
paddle/phi/kernels/cpu/nanmedian_kernel.cc
浏览文件 @
3d4d995f
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/
impl/nanmedian_kernel_impl
.h"
#include "paddle/phi/kernels/
funcs/nanmedian_utils
.h"
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace
phi
{
namespace
phi
{
...
@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
...
@@ -31,7 +31,6 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t
pre_dim
,
int64_t
pre_dim
,
T
*
o_ptr
,
T
*
o_ptr
,
int64_t
*
m_ptr
)
{
int64_t
*
m_ptr
)
{
bool
should_ignore_nan
=
ignore_nan
;
DenseTensor
sort_out
;
DenseTensor
sort_out
;
DenseTensor
sort_indices
;
DenseTensor
sort_indices
;
auto
sort_dim
=
x
.
dims
();
auto
sort_dim
=
x
.
dims
();
...
@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx,
...
@@ -52,7 +51,7 @@ void CalcMedianFunc(const Context& dev_ctx,
int64_t
offset
=
0
;
int64_t
offset
=
0
;
int64_t
i
=
0
;
int64_t
i
=
0
;
bool
is_ori_odd
=
stride
&
1
;
bool
is_ori_odd
=
stride
&
1
;
if
(
should_
ignore_nan
)
{
if
(
ignore_nan
)
{
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
offset
=
i
*
sort_k
;
offset
=
i
*
sort_k
;
if
(
nan_counts
[
i
]
==
stride
)
{
if
(
nan_counts
[
i
]
==
stride
)
{
...
@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx,
...
@@ -107,11 +106,11 @@ void CalcMedianFunc(const Context& dev_ctx,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ProcessMedianKernel
(
const
Context
&
dev_ctx
,
void
ProcessMedianKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
T
*
o_ptr
,
DenseTensor
*
out
,
int64_t
*
m_ptr
,
DenseTensor
*
median_index
)
{
bool
ignore_nan
)
{
const
T
*
x_data
=
x
.
data
<
T
>
();
bool
should_ignore_nan
=
ignore_nan
;
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
)
;
const
T
*
x_ptr
=
x
.
data
<
T
>
(
);
int64_t
*
m_data
=
dev_ctx
.
template
Alloc
<
int64_t
>(
median_index
);
int64_t
numel
=
x
.
numel
();
int64_t
numel
=
x
.
numel
();
auto
x_dim
=
x
.
dims
();
auto
x_dim
=
x
.
dims
();
...
@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -122,7 +121,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
int64_t
max_valid_num
=
0
;
int64_t
max_valid_num
=
0
;
std
::
vector
<
int64_t
>
nan_counts
;
std
::
vector
<
int64_t
>
nan_counts
;
if
(
should_ignore_nan
)
{
bool
ignore_nan
=
true
;
if
(
ignore_nan
)
{
int64_t
total_nan_num
=
0
;
int64_t
total_nan_num
=
0
;
std
::
vector
<
T
>
col_vec
;
std
::
vector
<
T
>
col_vec
;
col_vec
.
reserve
(
stride
);
col_vec
.
reserve
(
stride
);
...
@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -133,7 +133,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
for
(
int64_t
i
=
0
;
i
<
pre_dim
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
pre_dim
;
i
++
)
{
col_vec
.
clear
();
col_vec
.
clear
();
col_vec
.
insert
(
col_vec
.
insert
(
col_vec
.
begin
(),
x_
ptr
+
i
*
stride
,
x_ptr
+
(
i
+
1
)
*
stride
);
col_vec
.
begin
(),
x_
data
+
i
*
stride
,
x_data
+
(
i
+
1
)
*
stride
);
nan_counts
[
i
]
=
nan_counts
[
i
]
=
std
::
count_if
(
col_vec
.
begin
(),
col_vec
.
end
(),
[
&
](
const
T
&
val
)
{
std
::
count_if
(
col_vec
.
begin
(),
col_vec
.
end
(),
[
&
](
const
T
&
val
)
{
return
std
::
isnan
(
static_cast
<
float
>
(
val
));
return
std
::
isnan
(
static_cast
<
float
>
(
val
));
...
@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -145,47 +145,25 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elems are nan
// all elems are nan
if
(
total_nan_num
==
numel
)
{
if
(
total_nan_num
==
numel
)
{
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
for
(
i
=
0
;
i
<
pre_dim
;
i
++
)
{
o
_ptr
[
i
]
=
x_ptr
[
0
]
;
o
ut_data
[
i
]
=
std
::
numeric_limits
<
T
>::
quiet_NaN
()
;
m_
ptr
[
2
*
i
]
=
-
1
;
m_
data
[
2
*
i
]
=
-
1
;
m_
ptr
[
2
*
i
+
1
]
=
-
1
;
m_
data
[
2
*
i
+
1
]
=
-
1
;
}
}
return
;
return
;
}
}
should_
ignore_nan
=
total_nan_num
>
0
;
ignore_nan
=
total_nan_num
>
0
;
}
}
int64_t
sort_k
=
should_
ignore_nan
?
max_valid_num
:
((
stride
>>
1
)
+
1
);
int64_t
sort_k
=
ignore_nan
?
max_valid_num
:
((
stride
>>
1
)
+
1
);
CalcMedianFunc
<
T
,
Context
>
(
dev_ctx
,
CalcMedianFunc
<
T
,
Context
>
(
dev_ctx
,
x
,
x
,
nan_counts
,
nan_counts
,
should_
ignore_nan
,
ignore_nan
,
sort_k
,
sort_k
,
stride
,
stride
,
pre_dim
,
pre_dim
,
o_ptr
,
out_data
,
m_ptr
);
m_data
);
}
template
<
typename
T
,
typename
Context
>
void
BaseMedianKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
IntArray
&
axes
,
DenseTensor
*
out
,
DenseTensor
*
median_index
,
bool
ignore_nan
)
{
DenseTensor
x
;
auto
rank
=
input
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
x
=
input
;
x
.
Resize
({
input
.
numel
()});
}
else
{
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
input
,
axes
,
&
x
);
}
T
*
o_ptr
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int64_t
*
m_ptr
=
dev_ctx
.
template
Alloc
<
int64_t
>(
median_index
);
ProcessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
o_ptr
,
m_ptr
,
ignore_nan
);
out
->
Resize
(
out
->
dims
());
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
...
@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
...
@@ -195,7 +173,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool
keepdim
UNUSED
,
bool
keepdim
UNUSED
,
DenseTensor
*
out
,
DenseTensor
*
out
,
DenseTensor
*
median_index
)
{
DenseTensor
*
median_index
)
{
BaseMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
out
,
median_index
,
true
);
DenseTensor
tmp_x
;
auto
rank
=
x
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
tmp_x
=
x
;
tmp_x
.
Resize
({
x
.
numel
()});
}
else
{
funcs
::
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
&
tmp_x
);
}
ProcessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
out
,
median_index
);
}
}
}
// namespace phi
}
// namespace phi
...
...
paddle/phi/kernels/
impl/nanmedian_kernel_impl
.h
→
paddle/phi/kernels/
funcs/nanmedian_utils
.h
浏览文件 @
3d4d995f
...
@@ -15,9 +15,51 @@
...
@@ -15,9 +15,51 @@
#pragma once
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_kernel.h"
namespace
phi
{
namespace
phi
{
namespace
funcs
{
template
<
typename
T
,
typename
Context
>
void
PostprocessMedianGradKernel
(
const
Context
&
dev_ctx
,
DenseTensor
*
input
,
const
IntArray
&
raw_axes
,
DenseTensor
*
x
)
{
auto
input_dim
=
input
->
dims
();
auto
rank
=
input_dim
.
size
();
std
::
vector
<
int64_t
>
axes
=
raw_axes
.
GetData
();
int64_t
axes_size
=
static_cast
<
int
>
(
axes
.
size
());
for
(
int64_t
i
=
0
;
i
<
axes_size
;
i
++
)
{
if
(
axes
[
i
]
<
0
)
{
axes
[
i
]
+=
rank
;
}
}
std
::
vector
<
int
>
trans_back
;
std
::
vector
<
int
>
reshape_back
;
trans_back
.
resize
(
rank
);
int
offset
=
0
;
for
(
int64_t
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
reshape_back
.
push_back
(
input_dim
[
i
]);
trans_back
[
i
]
=
offset
;
offset
+=
1
;
}
}
for
(
int64_t
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
trans_back
[
i
]
=
offset
;
reshape_back
.
push_back
(
input_dim
[
i
]);
offset
+=
1
;
}
}
input
->
Resize
(
make_ddim
(
reshape_back
));
funcs
::
TransCompute
<
Context
,
T
>
(
static_cast
<
int
>
(
trans_back
.
size
()),
dev_ctx
,
*
input
,
x
,
trans_back
);
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
PreprocessMedianKernel
(
const
Context
&
dev_ctx
,
void
PreprocessMedianKernel
(
const
Context
&
dev_ctx
,
...
@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
...
@@ -65,4 +107,5 @@ void PreprocessMedianKernel(const Context& dev_ctx,
x
->
Resize
(
make_ddim
(
reshape
));
x
->
Resize
(
make_ddim
(
reshape
));
}
}
}
// namespace funcs
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu
浏览文件 @
3d4d995f
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/
impl/nanmedian_grad_kernel_impl
.h"
#include "paddle/phi/kernels/
funcs/nanmedian_utils
.h"
namespace
phi
{
namespace
phi
{
...
@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) {
...
@@ -30,23 +30,26 @@ inline int GET_BLOCKS(const int N) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
KernelNanmedianGrad
(
const
T
*
x_
ptr
,
__global__
void
KernelNanmedianGrad
(
const
T
*
x_
data
,
const
int64_t
*
medians_ptr
,
const
int64_t
*
medians_ptr
,
const
T
*
out_grad_ptr
,
const
T
*
out_grad_ptr
,
T
*
x_grad_ptr
,
T
*
dx_data
,
int64_t
stride
,
int64_t
stride
,
int64_t
pre_dim
,
int64_t
pre_dim
)
{
T
div_factor
)
{
CUDA_KERNEL_LOOP
(
index
,
pre_dim
)
{
CUDA_KERNEL_LOOP
(
index
,
pre_dim
)
{
int64_t
offset
=
index
*
stride
;
int64_t
offset
=
index
*
stride
;
printf
(
"index: %d
\n
"
,
index
);
printf
(
"medians_ptr[2 * index]: %d
\n
"
,
medians_ptr
[
2
*
index
]);
printf
(
"medians_ptr[2 * index+1]: %d
\n
"
,
medians_ptr
[
2
*
index
+
1
]);
if
(
medians_ptr
[
2
*
index
]
>=
0
)
{
if
(
medians_ptr
[
2
*
index
]
>=
0
)
{
if
(
medians_ptr
[
2
*
index
]
==
medians_ptr
[
2
*
index
+
1
])
{
if
(
medians_ptr
[
2
*
index
]
==
medians_ptr
[
2
*
index
+
1
])
{
x_grad_ptr
[
offset
+
medians_ptr
[
2
*
index
]]
=
out_grad_ptr
[
index
];
dx_data
[
offset
+
medians_ptr
[
2
*
index
]]
=
out_grad_ptr
[
index
];
}
else
{
}
else
{
x_grad_ptr
[
offset
+
medians_ptr
[
2
*
index
]]
=
dx_data
[
offset
+
medians_ptr
[
2
*
index
]]
=
out_grad_ptr
[
index
]
/
div_factor
;
out_grad_ptr
[
index
]
/
static_cast
<
T
>
(
2.0
)
;
x_grad_ptr
[
offset
+
medians_ptr
[
2
*
index
+
1
]]
=
dx_data
[
offset
+
medians_ptr
[
2
*
index
+
1
]]
=
out_grad_ptr
[
index
]
/
div_factor
;
out_grad_ptr
[
index
]
/
static_cast
<
T
>
(
2.0
)
;
}
}
}
}
}
}
...
@@ -57,14 +60,17 @@ void CalcMedianGradKernel(const Context& dev_ctx,
...
@@ -57,14 +60,17 @@ void CalcMedianGradKernel(const Context& dev_ctx,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
,
DenseTensor
*
x_grad
)
{
T
*
x_grad_ptr
)
{
T
*
dx_data
=
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
if
(
!
dx_data
)
return
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0
));
VLOG
(
0
)
<<
"x_grad->dims(): "
<<
x_grad
->
dims
();
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
const
T
*
x_
ptr
=
x
.
data
<
T
>
();
const
T
*
x_
data
=
x
.
data
<
T
>
();
const
int64_t
*
m_
ptr
=
median_index
.
data
<
int64_t
>
();
const
int64_t
*
m_
data
=
median_index
.
data
<
int64_t
>
();
const
T
*
out_grad_ptr
=
out_grad
.
data
<
T
>
();
const
T
*
out_grad_ptr
=
out_grad
.
data
<
T
>
();
int64_t
numel
=
x
.
numel
();
int64_t
numel
=
x
.
numel
();
...
@@ -73,42 +79,38 @@ void CalcMedianGradKernel(const Context& dev_ctx,
...
@@ -73,42 +79,38 @@ void CalcMedianGradKernel(const Context& dev_ctx,
int64_t
stride
=
x_dim
[
x_rank
-
1
];
int64_t
stride
=
x_dim
[
x_rank
-
1
];
int64_t
pre_dim
=
numel
/
stride
;
int64_t
pre_dim
=
numel
/
stride
;
T
div_factor
=
static_cast
<
T
>
(
2.0
);
KernelNanmedianGrad
<
T
>
KernelNanmedianGrad
<
T
>
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
x_ptr
,
m_ptr
,
out_grad_ptr
,
x_grad_ptr
,
stride
,
pre_dim
,
div_factor
);
x_data
,
m_data
,
out_grad_ptr
,
dx_data
,
stride
,
pre_dim
);
}
template
<
typename
T
,
typename
Context
>
void
BaseMedianGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
IntArray
&
axes
,
DenseTensor
*
x_grad
)
{
auto
rank
=
x
.
dims
().
size
();
T
*
x_grad_ptr
=
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
if
(
axes
.
size
()
&&
(
rank
>
1
))
{
DenseTensor
tmp_x_grad
(
*
x_grad
);
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
median_index
,
out_grad
,
&
tmp_x_grad
,
x_grad_ptr
);
PostprocessMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
&
tmp_x_grad
,
axes
,
x_grad
);
}
else
{
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
median_index
,
out_grad
,
x_grad
,
x_grad_ptr
);
}
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
NanmedianGradKernel
(
const
Context
&
dev_ctx
,
void
NanmedianGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
x
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
median_index
,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
const
IntArray
&
axes
,
const
IntArray
&
axes
,
bool
keep
_dim
,
bool
keep
dim
UNUSED
,
DenseTensor
*
x_grad
)
{
DenseTensor
*
x_grad
)
{
BaseMedianGradKernel
<
T
,
Context
>
(
DenseTensor
tmp_x
;
dev_ctx
,
input
,
median_index
,
out_grad
,
axes
,
x_grad
);
auto
rank
=
x
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
tmp_x
=
x
;
tmp_x
.
Resize
({
x
.
numel
()});
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
median_index
,
out_grad
,
x_grad
);
}
else
{
funcs
::
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
&
tmp_x
);
DenseTensor
tmp_x_grad
;
tmp_x_grad
.
Resize
(
x_grad
->
dims
());
CalcMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
median_index
,
out_grad
,
&
tmp_x_grad
);
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
funcs
::
PostprocessMedianGradKernel
<
T
,
Context
>
(
dev_ctx
,
&
tmp_x_grad
,
axes
,
x_grad
);
}
}
}
}
// namespace phi
}
// namespace phi
...
...
paddle/phi/kernels/gpu/nanmedian_kernel.cu
浏览文件 @
3d4d995f
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/
impl/nanmedian_kernel_impl
.h"
#include "paddle/phi/kernels/
funcs/nanmedian_utils
.h"
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/phi/kernels/top_k_kernel.h"
namespace
phi
{
namespace
phi
{
...
@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr,
...
@@ -138,14 +138,13 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
ProcessMedianKernel
(
const
Context
&
dev_ctx
,
void
ProcessMedianKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
bool
ignore_nan
,
DenseTensor
*
out
,
DenseTensor
*
out
,
int64_t
*
m_ptr
)
{
DenseTensor
*
median_index
)
{
bool
should_ignore_nan
=
ignore_nan
;
auto
stream
=
dev_ctx
.
stream
();
auto
stream
=
dev_ctx
.
stream
();
const
T
*
x_data
=
x
.
data
<
T
>
();
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int64_t
*
m_data
=
dev_ctx
.
template
Alloc
<
int64_t
>(
median_index
);
const
T
*
x_ptr
=
x
.
data
<
T
>
();
T
*
o_ptr
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
int64_t
numel
=
x
.
numel
();
int64_t
numel
=
x
.
numel
();
auto
x_dim
=
x
.
dims
();
auto
x_dim
=
x
.
dims
();
int64_t
x_rank
=
x_dim
.
size
();
int64_t
x_rank
=
x_dim
.
size
();
...
@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -156,7 +155,9 @@ void ProcessMedianKernel(const Context& dev_ctx,
DenseTensor
nan_counts
,
nan_stat
;
DenseTensor
nan_counts
,
nan_stat
;
int64_t
*
nan_counts_ptr
;
int64_t
*
nan_counts_ptr
;
int64_t
max_valid_num
=
0
;
int64_t
max_valid_num
=
0
;
if
(
should_ignore_nan
)
{
bool
ignore_nan
=
true
;
if
(
ignore_nan
)
{
nan_counts
.
Resize
(
phi
::
make_ddim
({
pre_dim
}));
nan_counts
.
Resize
(
phi
::
make_ddim
({
pre_dim
}));
dev_ctx
.
template
Alloc
<
int64_t
>(
&
nan_counts
);
dev_ctx
.
template
Alloc
<
int64_t
>(
&
nan_counts
);
nan_counts_ptr
=
nan_counts
.
data
<
int64_t
>
();
nan_counts_ptr
=
nan_counts
.
data
<
int64_t
>
();
...
@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -167,7 +168,7 @@ void ProcessMedianKernel(const Context& dev_ctx,
KernelNanCounts
<
T
><<<
GET_BLOCKS
(
numel
),
KernelNanCounts
<
T
><<<
GET_BLOCKS
(
numel
),
PADDLE_CUDA_NUM_THREADS
,
PADDLE_CUDA_NUM_THREADS
,
pre_dim
*
sizeof
(
int64_t
),
pre_dim
*
sizeof
(
int64_t
),
stream
>>>
(
x_
ptr
,
stream
>>>
(
x_
data
,
numel
,
numel
,
pre_dim
,
pre_dim
,
stride
,
stride
,
...
@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -189,15 +190,19 @@ void ProcessMedianKernel(const Context& dev_ctx,
// all elements are nan values
// all elements are nan values
T
nan_val
=
std
::
numeric_limits
<
T
>::
quiet_NaN
();
T
nan_val
=
std
::
numeric_limits
<
T
>::
quiet_NaN
();
if
(
nan_stat_cpu_ptr
[
0
]
==
numel
)
{
if
(
nan_stat_cpu_ptr
[
0
]
==
numel
)
{
FullLikeKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
nan_val
,
x
.
dtype
(),
out
);
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_nan
;
set_nan
(
dev_ctx
,
out
,
nan_val
);
phi
::
funcs
::
SetConstant
<
Context
,
int64_t
>
set_negatvie
;
set_negatvie
(
dev_ctx
,
median_index
,
static_cast
<
int64_t
>
(
-
1
));
return
;
return
;
}
}
should_
ignore_nan
=
nan_stat_cpu_ptr
[
0
]
>
0
;
ignore_nan
=
nan_stat_cpu_ptr
[
0
]
>
0
;
max_valid_num
=
nan_stat_cpu_ptr
[
1
];
max_valid_num
=
nan_stat_cpu_ptr
[
1
];
}
}
int64_t
sort_k
=
should_
ignore_nan
?
max_valid_num
:
((
stride
>>
1
)
+
1
);
int64_t
sort_k
=
ignore_nan
?
max_valid_num
:
((
stride
>>
1
)
+
1
);
bool
is_ori_odd
=
stride
&
1
;
bool
is_ori_odd
=
stride
&
1
;
DenseTensor
sort_out
,
sort_indices
;
DenseTensor
sort_out
,
sort_indices
;
...
@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -217,14 +222,14 @@ void ProcessMedianKernel(const Context& dev_ctx,
T
div_factor
=
static_cast
<
T
>
(
2.0
);
T
div_factor
=
static_cast
<
T
>
(
2.0
);
T
nan_val
=
std
::
numeric_limits
<
T
>::
quiet_NaN
();
T
nan_val
=
std
::
numeric_limits
<
T
>::
quiet_NaN
();
if
(
should_
ignore_nan
)
{
if
(
ignore_nan
)
{
CalcNanmedianKernel
<
T
>
CalcNanmedianKernel
<
T
>
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
sort_out_ptr
,
sort_out_ptr
,
sort_indices_ptr
,
sort_indices_ptr
,
nan_counts_ptr
,
nan_counts_ptr
,
m_
ptr
,
m_
data
,
o
_ptr
,
o
ut_data
,
is_ori_odd
,
is_ori_odd
,
pre_dim
,
pre_dim
,
max_valid_num
,
max_valid_num
,
...
@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -236,8 +241,8 @@ void ProcessMedianKernel(const Context& dev_ctx,
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
<<<
GET_BLOCKS
(
pre_dim
),
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
sort_out_ptr
,
sort_out_ptr
,
sort_indices_ptr
,
sort_indices_ptr
,
m_
ptr
,
m_
data
,
o
_ptr
,
o
ut_data
,
div_factor
,
div_factor
,
is_ori_odd
,
is_ori_odd
,
pre_dim
,
pre_dim
,
...
@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx,
...
@@ -245,27 +250,6 @@ void ProcessMedianKernel(const Context& dev_ctx,
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
BaseMedianKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
IntArray
&
axes
,
bool
ignore_nan
,
DenseTensor
*
out
,
DenseTensor
*
median_index
)
{
DenseTensor
x
;
auto
rank
=
input
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
x
=
input
;
x
.
Resize
({
input
.
numel
()});
}
else
{
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
input
,
axes
,
&
x
);
}
int64_t
*
m_ptr
=
dev_ctx
.
template
Alloc
<
int64_t
>(
median_index
);
ProcessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
ignore_nan
,
out
,
m_ptr
);
out
->
Resize
(
out
->
dims
());
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
NanmedianKernel
(
const
Context
&
dev_ctx
,
void
NanmedianKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx,
...
@@ -273,7 +257,16 @@ void NanmedianKernel(const Context& dev_ctx,
bool
keepdim
,
bool
keepdim
,
DenseTensor
*
out
,
DenseTensor
*
out
,
DenseTensor
*
median_index
)
{
DenseTensor
*
median_index
)
{
BaseMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
true
,
out
,
median_index
);
DenseTensor
tmp_x
;
auto
rank
=
x
.
dims
().
size
();
if
((
axes
.
size
()
==
0
)
||
rank
<=
1
)
{
tmp_x
=
x
;
tmp_x
.
Resize
({
x
.
numel
()});
}
else
{
funcs
::
PreprocessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axes
,
&
tmp_x
);
}
ProcessMedianKernel
<
T
,
Context
>
(
dev_ctx
,
tmp_x
,
out
,
median_index
);
}
}
}
// namespace phi
}
// namespace phi
...
...
paddle/phi/kernels/impl/nanmedian_grad_kernel_impl.h
已删除
100644 → 0
浏览文件 @
ca59c72b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/nanmedian_grad_kernel.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
PostprocessMedianGradKernel
(
const
Context
&
dev_ctx
,
DenseTensor
*
input
,
const
IntArray
&
raw_axes
,
DenseTensor
*
x
)
{
auto
input_dim
=
input
->
dims
();
auto
rank
=
input_dim
.
size
();
std
::
vector
<
int64_t
>
axes
=
raw_axes
.
GetData
();
int64_t
axes_size
=
static_cast
<
int
>
(
axes
.
size
());
for
(
int64_t
i
=
0
;
i
<
axes_size
;
i
++
)
{
if
(
axes
[
i
]
<
0
)
{
axes
[
i
]
+=
rank
;
}
}
std
::
vector
<
int
>
trans_back
;
std
::
vector
<
int
>
reshape_back
;
trans_back
.
reserve
(
rank
);
trans_back
.
resize
(
rank
);
int
offset
=
0
;
for
(
int64_t
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
==
axes
.
end
())
{
reshape_back
.
push_back
(
input_dim
[
i
]);
trans_back
[
i
]
=
offset
;
offset
+=
1
;
}
}
for
(
int64_t
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
{
trans_back
[
i
]
=
offset
;
reshape_back
.
push_back
(
input_dim
[
i
]);
offset
+=
1
;
}
}
input
->
Resize
(
make_ddim
(
reshape_back
));
funcs
::
TransCompute
<
Context
,
T
>
(
static_cast
<
int
>
(
trans_back
.
size
()),
dev_ctx
,
*
input
,
x
,
trans_back
);
}
}
// namespace phi
python/paddle/tensor/stat.py
浏览文件 @
3d4d995f
...
@@ -255,7 +255,7 @@ def numel(x, name=None):
...
@@ -255,7 +255,7 @@ def numel(x, name=None):
return
out
return
out
def
nanmedian
(
x
,
axis
=
None
,
keepdim
=
Tru
e
,
name
=
None
):
def
nanmedian
(
x
,
axis
=
None
,
keepdim
=
Fals
e
,
name
=
None
):
r
"""
r
"""
Compute the median along the specified axis, while ignoring NaNs.
Compute the median along the specified axis, while ignoring NaNs.
...
@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
...
@@ -273,7 +273,7 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
in the output Tensor. If ``keepdim`` is True, the dimensions of
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is
Tru
e.
the output Tensor is squeezed in ``axis`` . Default is
Fals
e.
name (str, optional): Name for the operation (optional, default is None).
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
For more information, please refer to :ref:`api_guide_Name`.
...
@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
...
@@ -287,16 +287,16 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
x = paddle.to_tensor([[float('nan'), 2. , 3. ], [0. , 1. , 2. ]])
y1 = x.nanmedian()
y1 = x.nanmedian()
# y1 is
[[2.]]
# y1 is
2.
y2 = x.nanmedian(0)
y2 = x.nanmedian(0)
# y2 is [
[0., 1.5, 2.5]
]
# y2 is [
0., 1.5, 2.5
]
y3 = x.nanmedian(0, keepdim=
Fals
e)
y3 = x.nanmedian(0, keepdim=
Tru
e)
# y3 is [
0., 1.5, 2.5
]
# y3 is [
[0., 1.5, 2.5]
]
y4 = x.nanmedian((0, 1))
y4 = x.nanmedian((0, 1))
# y4 is
[[2.]]
# y4 is
2.
"""
"""
if
not
isinstance
(
x
,
Variable
):
if
not
isinstance
(
x
,
Variable
):
raise
TypeError
(
"In median, the input x should be a Tensor."
)
raise
TypeError
(
"In median, the input x should be a Tensor."
)
...
@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
...
@@ -304,7 +304,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
if
isinstance
(
axis
,
(
list
,
tuple
))
and
len
(
axis
)
==
0
:
if
isinstance
(
axis
,
(
list
,
tuple
))
and
len
(
axis
)
==
0
:
raise
ValueError
(
"Axis list should not be empty."
)
raise
ValueError
(
"Axis list should not be empty."
)
dims
=
len
(
x
.
shape
)
if
axis
is
None
:
if
axis
is
None
:
axis
=
[]
axis
=
[]
elif
isinstance
(
axis
,
tuple
):
elif
isinstance
(
axis
,
tuple
):
...
@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
...
@@ -312,24 +311,6 @@ def nanmedian(x, axis=None, keepdim=True, name=None):
elif
isinstance
(
axis
,
int
):
elif
isinstance
(
axis
,
int
):
axis
=
[
axis
]
axis
=
[
axis
]
if
not
isinstance
(
axis
,
list
):
raise
ValueError
(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
for
i
in
range
(
len
(
axis
)):
if
not
isinstance
(
axis
[
i
],
int
)
or
not
(
axis
[
i
]
<
dims
and
axis
[
i
]
>=
-
dims
):
raise
ValueError
(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if
axis
[
i
]
<
0
:
axis
[
i
]
+=
dims
if
len
(
axis
)
!=
len
(
set
(
axis
)):
raise
ValueError
(
"Axis has duplicated elements."
)
if
in_dynamic_mode
():
if
in_dynamic_mode
():
return
_C_ops
.
nanmedian
(
x
,
axis
,
keepdim
)
return
_C_ops
.
nanmedian
(
x
,
axis
,
keepdim
)
else
:
else
:
...
...
test/legacy_test/test_nanmedian.py
浏览文件 @
3d4d995f
...
@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase):
...
@@ -125,6 +125,7 @@ class TestNanmedian(unittest.TestCase):
pd_res
=
paddle
.
nanmedian
(
pd_res
=
paddle
.
nanmedian
(
paddle
.
to_tensor
(
data
),
keepdim
=
keep_dim
paddle
.
to_tensor
(
data
),
keepdim
=
keep_dim
)
)
assert
np_res
.
shape
==
pd_res
.
numpy
().
shape
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
np_res
,
pd_res
.
numpy
(),
rtol
=
1e-05
,
equal_nan
=
True
np_res
,
pd_res
.
numpy
(),
rtol
=
1e-05
,
equal_nan
=
True
)
)
...
@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase):
...
@@ -187,6 +188,23 @@ class TestNanmedian(unittest.TestCase):
x_np
[
0
,
:]
=
np
.
nan
x_np
[
0
,
:]
=
np
.
nan
x_np
[
1
,
:
3
]
=
np
.
nan
x_np
[
1
,
:
3
]
=
np
.
nan
x_np
[
2
,
3
:]
=
np
.
nan
x_np
[
2
,
3
:]
=
np
.
nan
x_tensor
=
paddle
.
to_tensor
(
x_np
,
stop_gradient
=
False
)
y
=
paddle
.
nanmedian
(
x_tensor
,
keepdim
=
True
)
dx
=
paddle
.
grad
(
y
,
x_tensor
)[
0
].
numpy
()
np_grad
=
np
.
zeros
(
shape
)
np_grad
[
1
,
3
]
=
0.5
np_grad
[
3
,
2
]
=
0.5
np
.
testing
.
assert_allclose
(
np_grad
,
dx
,
rtol
=
1e-05
,
equal_nan
=
True
)
def
test_check_grad_axis
(
self
):
paddle
.
disable_static
(
place
=
self
.
place
)
shape
=
(
4
,
5
)
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
shape
).
astype
(
np
.
float64
)
x_np
[
0
,
:]
=
np
.
nan
x_np
[
1
,
:
3
]
=
np
.
nan
x_np
[
2
,
3
:]
=
np
.
nan
x_np_sorted
=
np
.
sort
(
x_np
)
x_np_sorted
=
np
.
sort
(
x_np
)
nan_counts
=
np
.
count_nonzero
(
np
.
isnan
(
x_np
).
astype
(
np
.
int32
),
axis
=
1
)
nan_counts
=
np
.
count_nonzero
(
np
.
isnan
(
x_np
).
astype
(
np
.
int32
),
axis
=
1
)
np_grad
=
np
.
zeros
(
shape
)
np_grad
=
np
.
zeros
(
shape
)
...
@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase):
...
@@ -205,10 +223,25 @@ class TestNanmedian(unittest.TestCase):
np_grad
[
i
,
j
]
=
1
if
is_odd
else
0.5
np_grad
[
i
,
j
]
=
1
if
is_odd
else
0.5
x_tensor
=
paddle
.
to_tensor
(
x_np
,
stop_gradient
=
False
)
x_tensor
=
paddle
.
to_tensor
(
x_np
,
stop_gradient
=
False
)
y
=
paddle
.
nanmedian
(
x_tensor
,
axis
=
1
,
keepdim
=
True
)
y
=
paddle
.
nanmedian
(
x_tensor
,
axis
=
1
)
dx
=
paddle
.
grad
(
y
,
x_tensor
)[
0
].
numpy
()
dx
=
paddle
.
grad
(
y
,
x_tensor
)[
0
].
numpy
()
np
.
testing
.
assert_allclose
(
np_grad
,
dx
,
rtol
=
1e-05
,
equal_nan
=
True
)
np
.
testing
.
assert_allclose
(
np_grad
,
dx
,
rtol
=
1e-05
,
equal_nan
=
True
)
def
test_check_grad_0d
(
self
):
paddle
.
disable_static
(
place
=
self
.
place
)
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
y
=
paddle
.
nanmedian
(
x
)
y
.
backward
()
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
np
.
testing
.
assert_allclose
(
x
.
grad
,
np
.
array
(
1.0
))
x
=
paddle
.
to_tensor
(
float
(
'nan'
),
stop_gradient
=
False
)
y
=
paddle
.
nanmedian
(
x
)
y
.
backward
()
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
np
.
testing
.
assert_allclose
(
x
.
grad
,
np
.
array
(
0.0
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
test/legacy_test/test_zero_dim_tensor.py
浏览文件 @
3d4d995f
...
@@ -179,6 +179,8 @@ reduce_api_list = [
...
@@ -179,6 +179,8 @@ reduce_api_list = [
paddle
.
mean
,
paddle
.
mean
,
paddle
.
nansum
,
paddle
.
nansum
,
paddle
.
nanmean
,
paddle
.
nanmean
,
paddle
.
median
,
paddle
.
nanmedian
,
paddle
.
min
,
paddle
.
min
,
paddle
.
max
,
paddle
.
max
,
paddle
.
amin
,
paddle
.
amin
,
...
@@ -202,7 +204,7 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -202,7 +204,7 @@ class TestReduceAPI(unittest.TestCase):
else
:
else
:
x
=
paddle
.
rand
([])
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
None
)
out
=
api
(
x
,
axis
=
None
)
out
.
retain_grads
()
out
.
retain_grads
()
out
.
backward
()
out
.
backward
()
...
@@ -212,9 +214,10 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -212,9 +214,10 @@ class TestReduceAPI(unittest.TestCase):
if
api
not
in
[
paddle
.
count_nonzero
]:
if
api
not
in
[
paddle
.
count_nonzero
]:
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
x
.
numpy
())
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
x
.
numpy
())
out_empty_list
=
api
(
x
,
[])
if
api
not
in
[
paddle
.
median
,
paddle
.
nanmedian
]:
self
.
assertEqual
(
out_empty_list
,
out
)
out_empty_list
=
api
(
x
,
axis
=
[])
self
.
assertEqual
(
out_empty_list
.
shape
,
[])
self
.
assertEqual
(
out_empty_list
,
out
)
self
.
assertEqual
(
out_empty_list
.
shape
,
[])
if
x
.
grad
is
not
None
:
if
x
.
grad
is
not
None
:
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
...
@@ -222,12 +225,12 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -222,12 +225,12 @@ class TestReduceAPI(unittest.TestCase):
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
np
.
array
(
1.0
))
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
np
.
array
(
1.0
))
np
.
testing
.
assert_allclose
(
out
.
grad
.
numpy
(),
np
.
array
(
1.0
))
np
.
testing
.
assert_allclose
(
out
.
grad
.
numpy
(),
np
.
array
(
1.0
))
out1
=
api
(
x
,
0
)
out1
=
api
(
x
,
axis
=
0
)
self
.
assertEqual
(
out1
.
shape
,
[])
self
.
assertEqual
(
out1
.
shape
,
[])
self
.
assertEqual
(
out1
,
out
)
self
.
assertEqual
(
out1
,
out
)
out1
.
backward
()
out1
.
backward
()
out2
=
api
(
x
,
-
1
)
out2
=
api
(
x
,
axis
=
-
1
)
self
.
assertEqual
(
out2
.
shape
,
[])
self
.
assertEqual
(
out2
.
shape
,
[])
self
.
assertEqual
(
out2
,
out
)
self
.
assertEqual
(
out2
,
out
)
out2
.
backward
()
out2
.
backward
()
...
@@ -236,13 +239,28 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -236,13 +239,28 @@ class TestReduceAPI(unittest.TestCase):
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
np
.
array
(
3.0
))
np
.
testing
.
assert_allclose
(
x
.
grad
.
numpy
(),
np
.
array
(
3.0
))
# 2) x is ND, reduce to 0D
# 2) x is 1D, axis=0, reduce to 0D
if
api
in
[
paddle
.
all
,
paddle
.
any
]:
x
=
paddle
.
randint
(
0
,
2
,
[
5
]).
astype
(
'bool'
)
else
:
x
=
paddle
.
rand
([
5
])
x
.
stop_gradient
=
False
out
=
api
(
x
,
axis
=
0
)
out
.
retain_grads
()
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
if
x
.
grad
is
not
None
:
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
5
])
# 3) x is ND, reduce to 0D
if
api
in
[
paddle
.
all
,
paddle
.
any
]:
if
api
in
[
paddle
.
all
,
paddle
.
any
]:
x
=
paddle
.
randint
(
0
,
2
,
[
3
,
5
]).
astype
(
'bool'
)
x
=
paddle
.
randint
(
0
,
2
,
[
3
,
5
]).
astype
(
'bool'
)
else
:
else
:
x
=
paddle
.
rand
([
3
,
5
])
x
=
paddle
.
rand
([
3
,
5
])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
None
)
out
=
api
(
x
,
axis
=
None
)
out
.
retain_grads
()
out
.
retain_grads
()
out
.
backward
()
out
.
backward
()
...
@@ -251,20 +269,20 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -251,20 +269,20 @@ class TestReduceAPI(unittest.TestCase):
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
3
,
5
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
3
,
5
])
#
3) x is 1D, axis=0, reduce to 0D
#
4) x is ND, reduce to 0D, keepdim=True
if
api
in
[
paddle
.
all
,
paddle
.
any
]:
if
api
in
[
paddle
.
all
,
paddle
.
any
]:
x
=
paddle
.
randint
(
0
,
2
,
[
5
]).
astype
(
'bool'
)
x
=
paddle
.
randint
(
0
,
2
,
[
3
,
5
]).
astype
(
'bool'
)
else
:
else
:
x
=
paddle
.
rand
([
5
])
x
=
paddle
.
rand
([
3
,
5
])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
0
)
out
=
api
(
x
,
keepdim
=
True
)
out
.
retain_grads
()
out
.
retain_grads
()
out
.
backward
()
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
if
x
.
grad
is
not
None
:
if
x
.
grad
is
not
None
:
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[
1
,
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
5
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
3
,
5
])
paddle
.
enable_static
()
paddle
.
enable_static
()
...
@@ -283,16 +301,17 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -283,16 +301,17 @@ class TestReduceAPI(unittest.TestCase):
else
:
else
:
x
=
paddle
.
rand
([])
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
None
)
out
=
api
(
x
,
axis
=
None
)
paddle
.
static
.
append_backward
(
out
)
paddle
.
static
.
append_backward
(
out
)
out_empty_list
=
api
(
x
,
None
)
if
api
not
in
[
paddle
.
median
,
paddle
.
nanmedian
]:
self
.
assertEqual
(
out_empty_list
.
shape
,
())
out_empty_list
=
api
(
x
,
axis
=
[])
self
.
assertEqual
(
out_empty_list
.
shape
,
())
out1
=
api
(
x
,
0
)
out1
=
api
(
x
,
axis
=
0
)
self
.
assertEqual
(
out1
.
shape
,
())
self
.
assertEqual
(
out1
.
shape
,
())
out2
=
api
(
x
,
-
1
)
out2
=
api
(
x
,
axis
=
-
1
)
self
.
assertEqual
(
out2
.
shape
,
())
self
.
assertEqual
(
out2
.
shape
,
())
fetch_list
=
[
x
,
out
]
fetch_list
=
[
x
,
out
]
...
@@ -317,7 +336,7 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -317,7 +336,7 @@ class TestReduceAPI(unittest.TestCase):
else
:
else
:
x
=
paddle
.
rand
([
3
,
5
])
x
=
paddle
.
rand
([
3
,
5
])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
None
)
out
=
api
(
x
,
axis
=
None
)
paddle
.
static
.
append_backward
(
out
)
paddle
.
static
.
append_backward
(
out
)
fetch_list
=
[
out
]
fetch_list
=
[
out
]
...
@@ -336,7 +355,7 @@ class TestReduceAPI(unittest.TestCase):
...
@@ -336,7 +355,7 @@ class TestReduceAPI(unittest.TestCase):
else
:
else
:
x
=
paddle
.
rand
([
5
])
x
=
paddle
.
rand
([
5
])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
out
=
api
(
x
,
0
)
out
=
api
(
x
,
axis
=
0
)
paddle
.
static
.
append_backward
(
out
)
paddle
.
static
.
append_backward
(
out
)
fetch_list
=
[
out
]
fetch_list
=
[
out
]
...
@@ -1200,54 +1219,6 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -1200,54 +1219,6 @@ class TestSundryAPI(unittest.TestCase):
out
=
paddle
.
argmax
(
x
,
keepdim
=
True
)
out
=
paddle
.
argmax
(
x
,
keepdim
=
True
)
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
def
test_median
(
self
):
# 1) x is 0D
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out1
=
paddle
.
median
(
x
,
0
)
out2
=
paddle
.
median
(
x
,
-
1
)
out3
=
paddle
.
median
(
x
,
None
)
out1
.
backward
()
out2
.
backward
()
out3
.
backward
()
self
.
assertEqual
(
out1
.
shape
,
[])
np
.
testing
.
assert_allclose
(
out1
,
x
)
self
.
assertEqual
(
out2
.
shape
,
[])
np
.
testing
.
assert_allclose
(
out2
,
x
)
self
.
assertEqual
(
out3
.
shape
,
[])
np
.
testing
.
assert_allclose
(
out3
,
x
)
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
np
.
testing
.
assert_allclose
(
x
.
grad
,
3.0
)
# 2) x is 1D
x
=
paddle
.
rand
([
5
])
x
.
stop_gradient
=
False
out
=
paddle
.
median
(
x
,
0
)
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
5
])
# 3) x is ND
x
=
paddle
.
rand
([
3
,
5
])
x
.
stop_gradient
=
False
out
=
paddle
.
median
(
x
,
None
)
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
3
,
5
])
# 4) x is ND, keepdim=True
x
=
paddle
.
rand
([
3
,
5
])
x
.
stop_gradient
=
False
out
=
paddle
.
median
(
x
,
keepdim
=
True
)
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[
1
,
1
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
3
,
5
])
def
test_kthvalue
(
self
):
def
test_kthvalue
(
self
):
# 1) x is 0D
# 1) x is 0D
x
=
paddle
.
randn
([])
x
=
paddle
.
randn
([])
...
@@ -1535,6 +1506,40 @@ class TestSundryAPI(unittest.TestCase):
...
@@ -1535,6 +1506,40 @@ class TestSundryAPI(unittest.TestCase):
self
.
assertEqual
(
out
.
grad
,
1.0
)
self
.
assertEqual
(
out
.
grad
,
1.0
)
self
.
assertEqual
(
x
.
grad
.
shape
,
[
2
,
3
])
self
.
assertEqual
(
x
.
grad
.
shape
,
[
2
,
3
])
def
test_nanquantile
(
self
):
# 1) x is 0D
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
quantile
(
x
,
0.5
,
axis
=
None
)
out
.
retain_grads
()
out
.
backward
()
out_empty_list
=
paddle
.
quantile
(
x
,
0.5
,
axis
=
[])
self
.
assertEqual
(
out_empty_list
,
out
)
self
.
assertEqual
(
x
.
shape
,
[])
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
out
,
x
)
self
.
assertEqual
(
x
.
grad
.
shape
,
[])
self
.
assertEqual
(
x
.
grad
,
1.0
)
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
,
1.0
)
# 2) x is ND with 'nan'
x
=
paddle
.
to_tensor
([[
float
(
'nan'
),
2.0
,
3.0
],
[
0.0
,
1.0
,
2.0
]])
x
.
stop_gradient
=
False
out
=
paddle
.
quantile
(
x
,
0.5
,
axis
=
None
)
out
.
retain_grads
()
out
.
backward
()
self
.
assertEqual
(
out
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
.
shape
,
[])
self
.
assertEqual
(
out
.
grad
,
1.0
)
self
.
assertEqual
(
x
.
grad
.
shape
,
[
2
,
3
])
def
test_flip
(
self
):
def
test_flip
(
self
):
x
=
paddle
.
rand
([])
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
x
.
stop_gradient
=
False
...
@@ -3442,40 +3447,6 @@ class TestSundryAPIStatic(unittest.TestCase):
...
@@ -3442,40 +3447,6 @@ class TestSundryAPIStatic(unittest.TestCase):
np
.
testing
.
assert_allclose
(
res
[
2
],
0.0
)
np
.
testing
.
assert_allclose
(
res
[
2
],
0.0
)
self
.
assertEqual
(
res
[
3
].
shape
,
())
self
.
assertEqual
(
res
[
3
].
shape
,
())
@
prog_scope
()
def
test_median
(
self
):
# 1) x is 0D
x
=
paddle
.
rand
([])
x
.
stop_gradient
=
False
out
=
paddle
.
median
(
x
)
paddle
.
static
.
append_backward
(
out
)
# 2) x is ND
x1
=
paddle
.
rand
([
3
,
5
])
x1
.
stop_gradient
=
False
out1
=
paddle
.
median
(
x1
)
paddle
.
static
.
append_backward
(
out1
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
x
,
out
,
x
.
grad_name
,
out1
,
x1
.
grad_name
,
],
)
self
.
assertEqual
(
res
[
1
].
shape
,
())
np
.
testing
.
assert_allclose
(
res
[
1
],
res
[
0
])
self
.
assertEqual
(
res
[
2
].
shape
,
())
np
.
testing
.
assert_allclose
(
res
[
2
],
1.0
)
self
.
assertEqual
(
res
[
3
].
shape
,
())
self
.
assertEqual
(
res
[
4
].
shape
,
(
3
,
5
))
@
prog_scope
()
@
prog_scope
()
def
test_kthvalue
(
self
):
def
test_kthvalue
(
self
):
# 1) x is 0D
# 1) x is 0D
...
@@ -3813,12 +3784,12 @@ class TestSundryAPIStatic(unittest.TestCase):
...
@@ -3813,12 +3784,12 @@ class TestSundryAPIStatic(unittest.TestCase):
x1
=
paddle
.
rand
([])
x1
=
paddle
.
rand
([])
x1
.
stop_gradient
=
False
x1
.
stop_gradient
=
False
out1
=
paddle
.
quantile
(
x1
,
0.5
,
axis
=
None
)
out1
=
paddle
.
quantile
(
x1
,
0.5
,
axis
=
None
)
paddle
.
static
.
append_backward
(
out1
.
sum
()
)
paddle
.
static
.
append_backward
(
out1
)
x2
=
paddle
.
rand
([
2
,
3
])
x2
=
paddle
.
rand
([
2
,
3
])
x2
.
stop_gradient
=
False
x2
.
stop_gradient
=
False
out2
=
paddle
.
quantile
(
x2
,
0.5
,
axis
=
None
)
out2
=
paddle
.
quantile
(
x2
,
0.5
,
axis
=
None
)
paddle
.
static
.
append_backward
(
out2
.
sum
()
)
paddle
.
static
.
append_backward
(
out2
)
out_empty_list
=
paddle
.
quantile
(
x1
,
0.5
,
axis
=
[])
out_empty_list
=
paddle
.
quantile
(
x1
,
0.5
,
axis
=
[])
self
.
assertEqual
(
out_empty_list
.
shape
,
())
self
.
assertEqual
(
out_empty_list
.
shape
,
())
...
@@ -3846,6 +3817,37 @@ class TestSundryAPIStatic(unittest.TestCase):
...
@@ -3846,6 +3817,37 @@ class TestSundryAPIStatic(unittest.TestCase):
self
.
assertEqual
(
res
[
5
].
shape
,
())
self
.
assertEqual
(
res
[
5
].
shape
,
())
self
.
assertEqual
(
res
[
5
],
1.0
)
self
.
assertEqual
(
res
[
5
],
1.0
)
@
prog_scope
()
def
test_nanquantile
(
self
):
# 1) x is 0D
x1
=
paddle
.
rand
([])
x1
.
stop_gradient
=
False
out1
=
paddle
.
nanquantile
(
x1
,
0.5
,
axis
=
None
)
paddle
.
static
.
append_backward
(
out1
)
# 2) x is ND with 'nan'
x2
=
paddle
.
to_tensor
([[
float
(
'nan'
),
2.0
,
3.0
],
[
0.0
,
1.0
,
2.0
]])
x2
.
stop_gradient
=
False
out2
=
paddle
.
nanquantile
(
x2
,
0.5
,
axis
=
None
)
print
(
out2
)
paddle
.
static
.
append_backward
(
out2
)
prog
=
paddle
.
static
.
default_main_program
()
res
=
self
.
exe
.
run
(
prog
,
fetch_list
=
[
out1
,
x1
.
grad_name
,
out2
,
x2
.
grad_name
,
],
)
self
.
assertEqual
(
res
[
0
].
shape
,
())
self
.
assertEqual
(
res
[
1
].
shape
,
())
self
.
assertEqual
(
res
[
2
].
shape
,
())
self
.
assertEqual
(
res
[
3
].
shape
,
(
2
,
3
))
@
prog_scope
()
@
prog_scope
()
def
test_flip
(
self
):
def
test_flip
(
self
):
x
=
paddle
.
rand
([])
x
=
paddle
.
rand
([])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录