PaddlePaddle / Paddle — commit b0dbf9fe (unverified)
Authored Apr 04, 2023 by chenxujun; committed via GitHub on Apr 04, 2023
[Hackathon No.62] Add BF16 support and unit tests for the pool3d operator, plus FP16/BF16 unit tests for the lgamma and masked_select operators (#51837)

* Add pool3d lgamma masked_select tests
* Fix code
Parent: f6f104d5
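For orientation, here is a minimal dygraph sketch of the coverage this commit adds. It is illustrative only: it assumes a CUDA build of Paddle on hardware with bfloat16 support, and the shapes are arbitrary.

import paddle

# bf16 3-D pooling (NCDHW layout), newly supported by the pool3d GPU kernels
x = paddle.rand([1, 2, 4, 8, 8]).astype('bfloat16')
y = paddle.nn.functional.max_pool3d(x, kernel_size=2, stride=2)

# fp16 lgamma and masked_select, newly registered on GPU
h = paddle.rand([4, 4]).astype('float16')
g = paddle.lgamma(h)
s = paddle.masked_select(h, h > 0.5)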
16 changed files, with 279 additions and 21 deletions (+279 −21):
paddle/phi/kernels/funcs/pooling.cu                            +26  −0
paddle/phi/kernels/funcs/select_impl.cu.h                      +1   −1
paddle/phi/kernels/gpu/lgamma_grad_kernel.cu                   +9   −2
paddle/phi/kernels/gpu/lgamma_kernel.cu                        +12  −2
paddle/phi/kernels/gpu/masked_select_grad_kernel.cu            +4   −1
paddle/phi/kernels/gpu/masked_select_kernel.cu                 +4   −1
paddle/phi/kernels/gpu/pool_grad_kernel.cu                     +3   −1
paddle/phi/kernels/gpu/pool_kernel.cu                          +3   −1
paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h              +5   −1
python/paddle/fluid/tests/unittests/test_lgamma_op.py          +37  −1
python/paddle/fluid/tests/unittests/test_masked_select_op.py   +71  −1
python/paddle/fluid/tests/unittests/test_pool3d_api.py         +33  −1
python/paddle/fluid/tests/unittests/test_pool3d_op.py          +64  −3
python/paddle/nn/functional/pooling.py                         +1   −1
python/paddle/tensor/math.py                                   +4   −2
python/paddle/tensor/search.py                                 +2   −2
paddle/phi/kernels/funcs/pooling.cu
@@ -993,6 +993,7 @@ template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, float>;
 template class MaxPool2dGradFunctor<phi::GPUContext, double>;
 template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool2dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
 template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1015,6 +1016,18 @@ template class Pool2dGradFunctor<phi::GPUContext,
 template class Pool2dGradFunctor<phi::GPUContext,
                                  AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dFunctor<phi::GPUContext,
+                             AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool2dGradFunctor<phi::GPUContext,
+                                 AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;

 template <typename PoolProcess, typename T>
 __global__ void KernelPool3D(const int nthreads,
@@ -1863,6 +1876,7 @@ template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, float>;
 template class MaxPool3dGradFunctor<phi::GPUContext, double>;
 template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
+template class MaxPool3dGradFunctor<phi::GPUContext, dtype::bfloat16>;
 template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
 template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
@@ -1879,12 +1893,24 @@ template class Pool3dFunctor<phi::GPUContext,
 template class Pool3dFunctor<phi::GPUContext,
                              AvgPool<dtype::float16>,
                              dtype::float16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             MaxPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
+template class Pool3dFunctor<phi::GPUContext,
+                             AvgPool<dtype::bfloat16>,
+                             dtype::bfloat16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  MaxPoolGrad<dtype::float16>,
                                  dtype::float16>;
 template class Pool3dGradFunctor<phi::GPUContext,
                                  AvgPoolGrad<dtype::float16>,
                                  dtype::float16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 MaxPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;
+template class Pool3dGradFunctor<phi::GPUContext,
+                                 AvgPoolGrad<dtype::bfloat16>,
+                                 dtype::bfloat16>;

 template <typename T1, typename T2>
 __global__ void KernelMaxPool2dWithIdx(const int nthreads,
paddle/phi/kernels/funcs/select_impl.cu.h
@@ -268,7 +268,7 @@ __device__ void SelectKernelImpl(OutT *out,
   using IdT = int64_t;               // Set index data type
   using Add = kps::AddFunctor<IdT>;  // for cumsum
-  using Cast = NonZeroFunctor<InT>;  // for mask
+  using Cast = NonZeroFunctor<MT>;   // for mask

   IdT init_idx = static_cast<IdT>(0.0f);
   MT init_mask = static_cast<MT>(0.0f);
paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
@@ -15,7 +15,14 @@
 #include "paddle/phi/kernels/lgamma_grad_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"

-PD_REGISTER_KERNEL(
-    lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
+PD_REGISTER_KERNEL(lgamma_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LgammaGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
paddle/phi/kernels/gpu/lgamma_kernel.cu
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/lgamma_kernel.h"

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -22,7 +23,9 @@ namespace phi {
 template <typename T>
 struct CudaLgammaFunctor {
   __device__ __forceinline__ T operator()(const T x) const {
-    return Eigen::numext::lgamma(x);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(x);
+    return static_cast<T>(Eigen::numext::lgamma(mp_x));
   }
 };

 template <typename T, typename Context>
@@ -38,4 +41,11 @@ void LgammaKernel(const Context& dev_ctx,
 }
 }  // namespace phi

-PD_REGISTER_KERNEL(
-    lgamma, GPU, ALL_LAYOUT, phi::LgammaKernel, float, double) {}
+PD_REGISTER_KERNEL(lgamma,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LgammaKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
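The functor change above routes the computation through MPTypeTrait: half-precision inputs are upcast to their multi-precision type (float) before Eigen's lgamma runs, and only the result is cast back down. A small numpy/scipy sketch of the idea (illustrative only, not Paddle code):

import numpy as np
from scipy import special

x = np.float16(7.5)
# Evaluate in float32 and cast the result back to the storage dtype,
# mirroring: static_cast<T>(lgamma(static_cast<MT>(x)))
out = np.float16(special.gammaln(np.float32(x)))
print(out)  # lgamma(7.5) ~ 7.534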
paddle/phi/kernels/gpu/masked_select_grad_kernel.cu
@@ -19,6 +19,7 @@
 #include <thrust/reverse.h>
 #include <thrust/scan.h>

+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
@@ -66,4 +67,6 @@ PD_REGISTER_KERNEL(masked_select_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
paddle/phi/kernels/gpu/masked_select_kernel.cu
@@ -20,6 +20,7 @@
 #include <thrust/scan.h>

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/select_impl.cu.h"
@@ -76,6 +77,8 @@ PD_REGISTER_KERNEL(masked_select,
                    float,
                    double,
                    int,
-                   int64_t) {
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
 }
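Note the registration body: input 1 (the mask) is pinned to phi::DataType::BOOL, so the mask stays a bool tensor regardless of which dtype X uses. A hedged usage sketch (assumes a CUDA build with the new fp16 kernel):

import paddle

x = paddle.rand([3]).astype('float16')
mask = paddle.to_tensor([True, False, True])  # mask must be bool
print(paddle.masked_select(x, mask))          # 1-D float16 tensor, 2 elements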
paddle/phi/kernels/gpu/pool_grad_kernel.cu
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/pool_grad_kernel.h"

+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h"
@@ -46,7 +47,8 @@ PD_REGISTER_KERNEL(pool3d_grad,
                    phi::Pool3dGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(max_pool3d_with_index_grad,
                    GPU,
                    ALL_LAYOUT,
paddle/phi/kernels/gpu/pool_kernel.cu
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/pool_kernel.h"

+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/pool_kernel_impl.h"
@@ -40,7 +41,8 @@ PD_REGISTER_KERNEL(pool3d,
                    phi::Pool3dKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(max_pool3d_with_index,
                    GPU,
                    ALL_LAYOUT,
paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once

 #include <unsupported/Eigen/SpecialFunctions>
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/for_range.h"

 namespace phi {
 template <typename T>
@@ -23,7 +24,10 @@ struct LgammaGradFunctor {
       : dout_(dout), x_(x), output_(output), numel_(numel) {}

   HOSTDEVICE void operator()(int64_t idx) const {
-    output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_dout = static_cast<MT>(dout_[idx]);
+    const MT mp_x = static_cast<MT>(x_[idx]);
+    output_[idx] = static_cast<T>(mp_dout * Eigen::numext::digamma(mp_x));
   }

  private:
python/paddle/fluid/tests/unittests/test_lgamma_op.py
@@ -16,10 +16,11 @@ import math
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from scipy import special

 import paddle
+from paddle.fluid import core

 paddle.enable_static()
@@ -56,6 +57,41 @@ class TestLgammaOpFp32(TestLgammaOp):
         self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)


+class TestLgammaFP16Op(TestLgammaOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestLgammaBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = 'lgamma'
+        self.python_api = paddle.lgamma
+        self.dtype = np.uint16
+        shape = (5, 20)
+        data = np.random.random(shape).astype("float32") + 1
+        self.inputs = {'X': convert_float_to_uint16(data)}
+        result = np.ones(shape).astype("float32")
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                result[i][j] = math.lgamma(data[i][j])
+        self.outputs = {'Out': convert_float_to_uint16(result)}
+
+    def test_check_output(self):
+        # bfloat16 output checks need the place argument set explicitly
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestLgammaOpApi(unittest.TestCase):
     def test_lgamma(self):
         paddle.disable_static()
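The bf16 tests store tensors as np.uint16 via convert_float_to_uint16, since numpy has no native bfloat16. A sketch of the convention this relies on (an assumption about the helper's behavior: bfloat16 keeps the upper 16 bits of the float32 representation, with round-to-nearest-even):

import numpy as np

def float32_to_bf16_bits(a):
    # View the float32 bit pattern, add the round-to-nearest-even bias,
    # then keep the upper 16 bits.
    bits = np.asarray(a, dtype=np.float32).view(np.uint32)
    bits = bits + 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16).astype(np.uint16)

print(float32_to_bf16_bits([1.0]))  # [16256] == 0x3F80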
python/paddle/fluid/tests/unittests/test_masked_select_op.py
@@ -15,9 +15,10 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
+from paddle.fluid import core


 def np_masked_select(x, mask):
@@ -59,6 +60,75 @@ class TestMaskedSelectOp2(TestMaskedSelectOp):
         self.shape = (168,)


+class TestMaskedSelectFP16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.float16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float16")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': x, 'Mask': mask}
+        self.outputs = {'Y': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectFP16Op1(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 18)
+
+
+class TestMaskedSelectFP16Op2(TestMaskedSelectFP16Op):
+    def init(self):
+        self.shape = (168,)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestMaskedSelectBF16Op(OpTest):
+    def setUp(self):
+        self.init()
+        self.op_type = "masked_select"
+        self.dtype = np.uint16
+        self.python_api = paddle.masked_select
+        x = np.random.random(self.shape).astype("float32")
+        mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
+        out = np_masked_select(x, mask)
+        self.inputs = {'X': convert_float_to_uint16(x), 'Mask': mask}
+        self.outputs = {'Y': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y')
+
+    def init(self):
+        self.shape = (50, 3)
+
+
+class TestMaskedSelectBF16Op1(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (6, 8, 9, 2)
+
+
+class TestMaskedSelectBF16Op2(TestMaskedSelectBF16Op):
+    def init(self):
+        self.shape = (168,)
+
+
 class TestMaskedSelectAPI(unittest.TestCase):
     def test_imperative_mode(self):
         paddle.disable_static()
python/paddle/fluid/tests/unittests/test_pool3d_api.py
@@ -354,6 +354,7 @@ class TestPool3D_API(unittest.TestCase):
             np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05)

     def test_pool3d(self):
+        paddle.enable_static()
         for place in self.places:

             self.check_max_dygraph_results(place)
@@ -366,7 +367,8 @@ class TestPool3D_API(unittest.TestCase):
             self.check_max_dygraph_ndhwc_results(place)
             self.check_max_dygraph_ceilmode_results(place)

-    def test_static_pf16_gpu(self):
+    def test_static_fp16_gpu(self):
+        paddle.enable_static()
         if paddle.fluid.core.is_compiled_with_cuda():
             place = paddle.CUDAPlace(0)
             with paddle.static.program_guard(
@@ -392,6 +394,36 @@ class TestPool3D_API(unittest.TestCase):
                 assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16])

+    def test_static_bf16_gpu(self):
+        paddle.enable_static()
+        if (
+            paddle.fluid.core.is_compiled_with_cuda()
+            and paddle.fluid.core.is_bfloat16_supported(core.CUDAPlace(0))
+        ):
+            place = paddle.CUDAPlace(0)
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                input = np.random.random([1, 2, 3, 32, 32]).astype(np.uint16)
+                x = paddle.static.data(
+                    name="x", shape=[1, 2, 3, 32, 32], dtype="bfloat16"
+                )
+
+                m = paddle.nn.AvgPool3D(kernel_size=2, stride=2, padding=0)
+                y = m(x)
+
+                exe = paddle.static.Executor(place)
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"x": input},
+                    fetch_list=[y],
+                )
+
+                assert np.array_equal(res[0].shape, [1, 2, 1, 16, 16])
+

 class TestPool3DError_API(unittest.TestCase):
     def test_error_api(self):
python/paddle/fluid/tests/unittests/test_pool3d_op.py
@@ -399,9 +399,9 @@ class TestPool3D_Op(OpTest):
         self.check_output()

     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        if self.has_cudnn() and self.pool_type != "max":
+        if (
+            self.has_cudnn() or self.dtype == np.uint16
+        ) and self.pool_type != "max":
             place = core.CUDAPlace(0)
             if core.is_compiled_with_rocm():
                 self.check_grad_with_place(
@@ -566,6 +566,46 @@ def create_test_fp16_class(parent):
     globals()[cls_name] = TestFp16Case


+def create_test_cudnn_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or does not support bfloat16",
+    )
+    class TestCUDNNBf16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place)
+
+    cls_name = "{}_{}".format(parent.__name__, "CUDNNBf16Op")
+    TestCUDNNBf16Case.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNBf16Case
+
+
+def create_test_bf16_class(parent):
+    @unittest.skipIf(
+        not core.is_compiled_with_cuda()
+        or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+        "core is not compiled with CUDA or does not support bfloat16",
+    )
+    class TestBf16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = False
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place)
+
+    cls_name = "{}_{}".format(parent.__name__, "Bf16Op")
+    TestBf16Case.__name__ = cls_name
+    globals()[cls_name] = TestBf16Case
+
+
 create_test_cudnn_fp16_class(TestPool3D_Op)
 create_test_cudnn_fp16_class(TestCase1)
 create_test_cudnn_fp16_class(TestCase2)
@@ -580,6 +620,20 @@ create_test_fp16_class(TestCase3)
 create_test_fp16_class(TestCase4)
 create_test_fp16_class(TestCase5)

+create_test_cudnn_bf16_class(TestPool3D_Op)
+create_test_cudnn_bf16_class(TestCase1)
+create_test_cudnn_bf16_class(TestCase2)
+create_test_cudnn_bf16_class(TestCase3)
+create_test_cudnn_bf16_class(TestCase4)
+create_test_cudnn_bf16_class(TestCase5)
+
+create_test_bf16_class(TestPool3D_Op)
+create_test_bf16_class(TestCase1)
+create_test_bf16_class(TestCase2)
+create_test_bf16_class(TestCase3)
+create_test_bf16_class(TestCase4)
+create_test_bf16_class(TestCase5)
+

 # ---- test ceil mode ------
 def create_test_cudnn_use_ceil_class(parent):
@@ -736,6 +790,13 @@ create_test_cudnn_fp16_class(TestCase3_AsyPadding)
 create_test_cudnn_fp16_class(TestCase4_AsyPadding)
 create_test_cudnn_fp16_class(TestCase5_AsyPadding)

+create_test_cudnn_bf16_class(TestPool3D_Op_AsyPadding)
+create_test_cudnn_bf16_class(TestCase1_AsyPadding)
+create_test_cudnn_bf16_class(TestCase2_AsyPadding)
+create_test_cudnn_bf16_class(TestCase3_AsyPadding)
+create_test_cudnn_bf16_class(TestCase4_AsyPadding)
+create_test_cudnn_bf16_class(TestCase5_AsyPadding)
+
 create_test_cudnn_use_ceil_class(TestPool3D_Op_AsyPadding)
 create_test_cudnn_use_ceil_class(TestCase1_AsyPadding)
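The create_test_*_class helpers above follow a common factory pattern in these test files: each call subclasses the parent test case, overrides the kernel dtype, renames the class, and publishes it via globals() so unittest discovery picks up a distinct variant per configuration. A stripped-down sketch (names hypothetical):

import unittest

def make_dtype_variant(parent, suffix, dtype):
    class Variant(parent):
        def init_kernel_type(self):
            self.dtype = dtype

    cls_name = "{}_{}".format(parent.__name__, suffix)
    Variant.__name__ = cls_name
    globals()[cls_name] = Variant  # visible to unittest discovery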
python/paddle/nn/functional/pooling.py
@@ -520,7 +520,7 @@ def avg_pool3d(
         op_type = "pool3d"
         helper = LayerHelper(op_type, **locals())
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'avg_pool3d'
+            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'avg_pool3d'
         )
         dtype = helper.input_dtype(input_param_name='x')
         pool_out = helper.create_variable_for_type_inference(dtype)
python/paddle/tensor/math.py
@@ -4027,7 +4027,7 @@ def lgamma(x, name=None):

     Args:
-        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
+        x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -4046,7 +4046,9 @@ def lgamma(x, name=None):
     if in_dygraph_mode():
         return _C_ops.lgamma(x)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lgamma')
+        check_variable_and_dtype(
+            x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'lgamma'
+        )
         helper = LayerHelper('lgamma', **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
         helper.append_op(type='lgamma', inputs={'X': x}, outputs={'Out': out})
python/paddle/tensor/search.py
@@ -807,7 +807,7 @@ def masked_select(x, mask, name=None):
     which is a tensor with data type of bool.

     Args:
-        x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
+        x (Tensor): The input Tensor, the data type can be int32, int64, uint16, float16, float32, float64.
         mask (Tensor): The Tensor containing the binary mask to index with, it's data type is bool.
         name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -838,7 +838,7 @@ def masked_select(x, mask, name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['float32', 'float64', 'int32', 'int64'],
+        ['float16', 'float32', 'float64', 'int32', 'int64', 'uint16'],
         'paddle.tensor.search.mask_select',
     )
     check_variable_and_dtype(