Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
643a268e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
643a268e
编写于
12月 21, 2021
作者:
S
sneaxiy
提交者:
GitHub
12月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support FP16 mean (#38289)
* mean first version * fix scalar mean * add fp16 dtype for api
上级
c197d73b
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
205 addition
and
63 deletion
+205
-63
paddle/fluid/operators/kernel_primitives/functor_primitives.h
...le/fluid/operators/kernel_primitives/functor_primitives.h
+11
-4
paddle/fluid/operators/mean_op.cu
paddle/fluid/operators/mean_op.cu
+24
-29
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+1
-1
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
+2
-0
paddle/fluid/operators/reduce_ops/reduce_mean_op.h
paddle/fluid/operators/reduce_ops/reduce_mean_op.h
+13
-0
paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
+6
-0
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+38
-19
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+5
-3
python/paddle/fluid/tests/unittests/test_mean_op.py
python/paddle/fluid/tests/unittests/test_mean_op.py
+103
-6
python/paddle/tensor/stat.py
python/paddle/tensor/stat.py
+2
-1
未找到文件。
paddle/fluid/operators/kernel_primitives/functor_primitives.h
浏览文件 @
643a268e
...
...
@@ -14,7 +14,10 @@
#pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -74,16 +77,20 @@ struct IdentityFunctor {
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
DivideFunctor
{
HOSTDEVICE
inline
DivideFunctor
()
{
n_inv
=
static_cast
<
Tx
>
(
1.0
f
);
}
private:
using
MPType
=
typename
::
paddle
::
operators
::
details
::
MPTypeTrait
<
Tx
>::
Type
;
public:
HOSTDEVICE
inline
DivideFunctor
()
{
n_inv
=
static_cast
<
MPType
>
(
1.0
f
);
}
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
((
Tx
)(
1.0
/
n
))
{}
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
((
MPType
)(
1.0
/
n
))
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
x
*
n_inv
);
return
static_cast
<
Ty
>
(
static_cast
<
MPType
>
(
x
)
*
n_inv
);
}
private:
Tx
n_inv
;
MPType
n_inv
;
};
/**
...
...
paddle/fluid/operators/mean_op.cu
浏览文件 @
643a268e
...
...
@@ -18,30 +18,23 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace
cub
=
hipcub
;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
DivideFunctor
{
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
(
static_cast
<
T
>
(
1.0
/
n
))
{}
HOSTDEVICE
inline
T
operator
()(
const
T
&
x
)
const
{
return
x
*
n_inv
;
}
private:
T
n_inv
;
};
template
<
typename
T
>
__global__
void
MeanRunKernel
(
const
T
*
in_data
,
T
*
out_data
,
int
N
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
T
data
=
in_data
[
0
]
;
auto
data
=
static_cast
<
MT
>
(
in_data
[
0
])
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
out_data
[
idx
]
=
data
/
(
static_cast
<
T
>
(
N
));
out_data
[
idx
]
=
static_cast
<
T
>
(
data
/
(
static_cast
<
MT
>
(
N
)
));
}
}
...
...
@@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel<T> {
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
size_prob
=
input
->
numel
();
const
T
*
in_data
=
input
->
data
<
T
>
();
T
*
out_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
numel
=
input
->
numel
();
auto
rank
=
input
->
dims
().
size
();
auto
place
=
context
.
GetPlace
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
DivideFunctor
<
T
>
transformer
(
size_prob
);
cub
::
TransformInputIterator
<
T
,
DivideFunctor
<
T
>
,
const
T
*>
trans_x
(
in_data
,
transformer
);
size_t
temp_storage_bytes
=
0
;
if
(
rank
==
0
)
{
// scalar
auto
gpu_place
=
BOOST_GET
(
platform
::
CUDAPlace
,
place
);
memory
::
Copy
(
gpu_place
,
out_data
,
gpu_place
,
in_data
,
numel
*
sizeof
(
T
),
stream
);
return
;
}
auto
err
=
cub
::
DeviceReduce
::
Sum
(
nullptr
,
temp_storage_bytes
,
trans_x
,
out_data
,
size_prob
,
stream
);
PADDLE_ENFORCE_GPU_SUCCESS
(
err
);
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
context
.
GetPlace
());
err
=
cub
::
DeviceReduce
::
Sum
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
out_data
,
size_prob
,
stream
);
PADDLE_ENFORCE_GPU_SUCCESS
(
err
);
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
using
Div
=
kernel_primitives
::
DivideFunctor
<
T
,
MT
>
;
std
::
vector
<
int
>
reduce_dims
;
reduce_dims
.
reserve
(
rank
);
for
(
decltype
(
rank
)
i
=
0
;
i
<
rank
;
++
i
)
{
reduce_dims
.
push_back
(
i
);
}
TensorReduceFunctorImpl
<
T
,
T
,
kernel_primitives
::
AddFunctor
,
Div
>
(
*
input
,
output
,
Div
(
numel
),
reduce_dims
,
stream
);
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_functor_op.h
浏览文件 @
643a268e
...
...
@@ -77,7 +77,7 @@ struct CustomSub {
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
CustomMean
{
using
Transformer
=
kps
::
DivideFunctor
<
Tx
>
;
using
Transformer
=
kps
::
DivideFunctor
<
Tx
,
Ty
>
;
inline
Ty
initial
()
{
return
static_cast
<
Ty
>
(
0.0
f
);
}
...
...
paddle/fluid/operators/reduce_ops/reduce_mean_op.cu
浏览文件 @
643a268e
...
...
@@ -19,5 +19,7 @@
REGISTER_OP_CUDA_KERNEL
(
reduce_mean
,
ops
::
ReduceCudaKernel
<
bool
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
,
ops
::
ReduceCudaKernel
<
paddle
::
platform
::
float16
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
,
ops
::
ReduceCudaKernel
<
float
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
,
ops
::
ReduceCudaKernel
<
double
,
kps
::
AddFunctor
,
kps
::
DivideFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_mean_op.h
浏览文件 @
643a268e
...
...
@@ -35,5 +35,18 @@ struct MeanGradFunctor {
}
};
// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean
// and reduce_mean_grad later.
struct
FP16MeanGradFunctor
{
template
<
typename
DeviceContext
,
typename
X
,
typename
Y
,
typename
DX
,
typename
DY
,
typename
Dim
>
void
operator
()(
const
DeviceContext
&
place
,
X
*
x
,
Y
*
y
,
DX
*
dx
,
DY
*
dy
,
const
Dim
&
dim
,
int
size
)
{
dx
->
device
(
place
)
=
(
dy
->
template
cast
<
float
>().
broadcast
(
dim
)
/
dx
->
template
cast
<
float
>().
constant
(
size
))
.
template
cast
<
platform
::
float16
>();
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu
浏览文件 @
643a268e
...
...
@@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel =
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
T
,
ops
::
MeanGradFunctor
,
true
>
;
using
FP16CUDAReduceMeanGradKernel
=
ops
::
ReduceGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
,
ops
::
FP16MeanGradFunctor
,
true
>
;
REGISTER_OP_CUDA_KERNEL
(
reduce_mean_grad
,
CUDAReduceMeanGradKernel
<
bool
>
,
FP16CUDAReduceMeanGradKernel
,
CUDAReduceMeanGradKernel
<
float
>
,
CUDAReduceMeanGradKernel
<
double
>
);
paddle/fluid/operators/reduce_ops/reduce_op.cu.h
浏览文件 @
643a268e
...
...
@@ -38,7 +38,9 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/string/string_helper.h"
// Reduce split or not, Whether to use ReduceHigherDim
#define REDUCE_SPLIT_BOUNDARY 512
...
...
@@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
}
}
template
<
typename
Tx
,
typename
Ty
,
template
<
typename
>
class
ReduceOp
,
typename
TransformOp
>
static
typename
std
::
enable_if
<!
std
::
is_same
<
Tx
,
platform
::
float16
>::
value
,
void
>::
type
CubTensorReduceFunctorImpl
(
const
Tx
*
x_data
,
Ty
*
y_data
,
const
TransformOp
&
transform
,
int
reduce_num
,
const
platform
::
Place
&
place
,
gpuStream_t
stream
)
{
auto
reducer
=
ReduceOp
<
Ty
>
();
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
transform
);
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
reducer
.
initial
(),
stream
);
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
place
);
cub
::
DeviceReduce
::
Reduce
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
y_data
,
reduce_num
,
reducer
,
reducer
.
initial
(),
stream
);
}
template
<
typename
Tx
,
typename
Ty
,
template
<
typename
>
class
ReduceOp
,
typename
TransformOp
>
static
typename
std
::
enable_if
<
std
::
is_same
<
Tx
,
platform
::
float16
>::
value
,
void
>::
type
CubTensorReduceFunctorImpl
(
const
Tx
*
x_data
,
Ty
*
y_data
,
const
TransformOp
&
transform
,
int
reduce_num
,
const
platform
::
Place
&
place
,
gpuStream_t
stream
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."
));
}
template
<
typename
Tx
,
typename
Ty
,
template
<
typename
>
class
ReduceOp
,
typename
TransformOp
>
void
TensorReduceFunctorImpl
(
const
framework
::
Tensor
&
x
,
framework
::
Tensor
*
y
,
const
TransformOp
&
transform
,
std
::
vector
<
int
>
origin_reduce_dims
,
const
std
::
vector
<
int
>&
origin_reduce_dims
,
gpuStream_t
stream
)
{
auto
x_dim
=
framework
::
vectorize
<
int
>
(
x
.
dims
());
auto
config
=
ReduceConfig
<
Ty
>
(
origin_reduce_dims
,
x_dim
);
...
...
@@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
}
config
.
SetOutputData
(
y_data
,
x
.
place
(),
&
tmp
);
bool
use_cub_reduce
=
(
config
.
reduce_num
==
numel
)
&&
(
!
std
::
is_same
<
Tx
,
paddle
::
platform
::
float16
>::
value
)
;
constexpr
bool
kIsTxFP16
=
std
::
is_same
<
Tx
,
paddle
::
platform
::
float16
>::
value
;
bool
use_cub_reduce
=
config
.
reduce_num
==
numel
&&
!
kIsTxFP16
;
if
(
use_cub_reduce
)
{
// launch CUB::Reduce
auto
reducer
=
ReduceOp
<
Ty
>
();
cub
::
TransformInputIterator
<
Ty
,
TransformOp
,
const
Tx
*>
trans_x
(
x_data
,
transform
);
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceReduce
::
Reduce
(
nullptr
,
temp_storage_bytes
,
trans_x
,
y_data
,
config
.
reduce_num
,
reducer
,
reducer
.
initial
(),
stream
);
framework
::
Tensor
tmp
;
auto
*
temp_storage
=
tmp
.
mutable_data
<
uint8_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
temp_storage_bytes
)}),
x
.
place
());
cub
::
DeviceReduce
::
Reduce
(
temp_storage
,
temp_storage_bytes
,
trans_x
,
y_data
,
config
.
reduce_num
,
reducer
,
reducer
.
initial
(),
stream
);
CubTensorReduceFunctorImpl
<
Tx
,
Ty
,
ReduceOp
,
TransformOp
>
(
x_data
,
y_data
,
transform
,
config
.
reduce_num
,
x
.
place
(),
stream
);
return
;
}
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
643a268e
...
...
@@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
reduce_dims
=
GetReduceDim
(
dims
,
input
->
dims
().
size
(),
reduce_all
);
int
reduce_num
=
1
;
for
(
int
i
=
0
;
i
<
input
->
dims
().
size
();
i
++
)
{
for
(
auto
i
:
reduce_dims
)
{
reduce_num
*=
(
input
->
dims
())[
i
];
}
gpuStream_t
stream
=
context
.
cuda_device_context
().
stream
();
...
...
@@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel<T> {
TensorReduceFunc
<
T
,
ReduceOp
,
TransformOp
>
(
*
input
,
output
,
reduce_dims
,
reduce_num
,
stream
));
}
else
{
TensorReduceFunctorImpl
<
T
,
T
,
ReduceOp
,
TransformOp
<
T
,
T
>>
(
*
input
,
output
,
TransformOp
<
T
,
T
>
(
reduce_num
),
reduce_dims
,
stream
);
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
TensorReduceFunctorImpl
<
T
,
T
,
ReduceOp
,
TransformOp
<
T
,
MPType
>>
(
*
input
,
output
,
TransformOp
<
T
,
MPType
>
(
reduce_num
),
reduce_dims
,
stream
);
}
}
};
...
...
python/paddle/fluid/tests/unittests/test_mean_op.py
浏览文件 @
643a268e
...
...
@@ -63,17 +63,25 @@ class TestMeanOpError(unittest.TestCase):
class
TestFP16MeanOp
(
TestMeanOp
):
def
init_dtype_type
(
self
):
self
.
dtype
=
np
.
float16
self
.
__class__
.
no_need_check_grad
=
True
def
test_check_output
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
2e-3
)
self
.
check_output_with_place
(
place
)
def
test_checkout_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
,
max_relative_error
=
0.8
)
with
fluid
.
dygraph
.
guard
():
x_np
=
np
.
random
.
random
((
10
,
10
)).
astype
(
self
.
dtype
)
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
fluid
.
layers
.
mean
(
x
)
dx
=
paddle
.
grad
(
y
,
x
)[
0
].
numpy
()
dx_expected
=
self
.
dtype
(
1.0
/
np
.
prod
(
x_np
.
shape
))
*
np
.
ones
(
x_np
.
shape
).
astype
(
self
.
dtype
)
self
.
assertTrue
(
np
.
array_equal
(
dx
,
dx_expected
))
@
OpTestTool
.
skip_if_not_cpu_bf16
()
...
...
@@ -98,6 +106,14 @@ def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
return
np
.
mean
(
x
,
axis
=
axis
,
keepdims
=
keepdim
)
def
ref_reduce_mean_grad
(
x
,
axis
,
dtype
):
if
reduce_all
:
axis
=
list
(
range
(
x
.
ndim
))
shape
=
[
x
.
shape
[
i
]
for
i
in
axis
]
return
(
1.0
/
np
.
prod
(
shape
)
*
np
.
ones
(
shape
)).
astype
(
dtype
)
class
TestReduceMeanOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'reduce_mean'
...
...
@@ -105,11 +121,13 @@ class TestReduceMeanOp(OpTest):
self
.
shape
=
[
2
,
3
,
4
,
5
]
self
.
axis
=
[
0
]
self
.
keepdim
=
False
self
.
reduce_all
=
False
self
.
set_attrs
()
np
.
random
.
seed
(
10
)
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
self
.
dtype
)
if
not
hasattr
(
self
,
"reduce_all"
):
self
.
reduce_all
=
(
not
self
.
axis
)
or
len
(
self
.
axis
)
==
len
(
x_np
)
out_np
=
ref_reduce_mean
(
x_np
,
self
.
axis
,
self
.
keepdim
,
self
.
reduce_all
)
self
.
inputs
=
{
'X'
:
x_np
}
self
.
outputs
=
{
'Out'
:
out_np
}
...
...
@@ -119,14 +137,39 @@ class TestReduceMeanOp(OpTest):
'reduce_all'
:
self
.
reduce_all
}
if
self
.
dtype
==
'float16'
:
self
.
__class__
.
no_need_check_grad
=
True
def
set_attrs
(
self
):
pass
def
test_check_output
(
self
):
self
.
check_output
()
if
self
.
dtype
!=
'float16'
:
self
.
check_output
()
else
:
if
not
core
.
is_compiled_with_cuda
():
return
place
=
paddle
.
CUDAPlace
(
0
)
self
.
check_output_with_place
(
place
=
place
)
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
[
'Out'
])
if
self
.
dtype
!=
'float16'
:
self
.
check_grad
([
'X'
],
[
'Out'
])
else
:
return
if
not
core
.
is_compiled_with_cuda
():
return
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
return
with
fluid
.
dygraph
.
guard
(
place
=
place
):
x
=
paddle
.
tensor
(
self
.
inputs
[
'X'
])
y
=
paddle
.
mean
(
x
,
axis
=
self
.
attrs
[
'dim'
],
keepdim
=
self
.
attrs
[
'keep_dim'
])
dx
=
paddle
.
grad
(
y
,
x
)[
0
].
numpy
()
dx_expected
=
ref_reduce_mean_grad
(
self
.
inputs
[
'X'
],
self
.
attrs
[
'dim'
],
self
.
dtype
)
self
.
assertTrue
(
np
.
array_equal
(
dx
,
dx_expected
))
class
TestReduceMeanOpDefaultAttrs
(
TestReduceMeanOp
):
...
...
@@ -146,47 +189,101 @@ class TestReduceMeanOpFloat32(TestReduceMeanOp):
self
.
dtype
=
'float32'
class
TestReduceMeanOpFloat16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
dtype
=
'float16'
class
TestReduceMeanOpShape1D
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
shape
=
[
100
]
class
TestReduceMeanOpShape1DFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
shape
=
[
100
]
self
.
dtype
=
'float16'
class
TestReduceMeanOpShape6D
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
shape
=
[
2
,
3
,
4
,
5
,
6
,
7
]
class
TestReduceMeanOpShape6DFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
shape
=
[
2
,
3
,
4
,
5
,
6
,
7
]
self
.
dtype
=
'float16'
class
TestReduceMeanOpAxisAll
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
0
,
1
,
2
,
3
]
class
TestReduceMeanOpAxisAllFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
0
,
1
,
2
,
3
]
self
.
dtype
=
'float16'
class
TestReduceMeanOpAxisTuple
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
(
0
,
1
,
2
)
class
TestReduceMeanOpAxisTupleFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
(
0
,
1
,
2
)
self
.
dtype
=
'float16'
class
TestReduceMeanOpAxisNegative
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
-
2
,
-
1
]
class
TestReduceMeanOpAxisNegativeFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
-
2
,
-
1
]
self
.
dtype
=
'float16'
class
TestReduceMeanOpKeepdimTrue1
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
keepdim
=
True
class
TestReduceMeanOpKeepdimTrue1FP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
keepdim
=
True
self
.
dtype
=
'float16'
class
TestReduceMeanOpKeepdimTrue2
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
0
,
1
,
2
,
3
]
self
.
keepdim
=
True
class
TestReduceMeanOpKeepdimTrue2FP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
axis
=
[
0
,
1
,
2
,
3
]
self
.
keepdim
=
True
self
.
dtype
=
'float16'
class
TestReduceMeanOpReduceAllTrue
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
reduce_all
=
True
class
TestReduceMeanOpReduceAllTrueFP16
(
TestReduceMeanOp
):
def
set_attrs
(
self
):
self
.
reduce_all
=
True
self
.
dtype
=
'float16'
class
TestMeanAPI
(
unittest
.
TestCase
):
# test paddle.tensor.stat.mean
...
...
python/paddle/tensor/stat.py
浏览文件 @
643a268e
...
...
@@ -92,7 +92,8 @@ def mean(x, axis=None, keepdim=False, name=None):
return
_C_ops
.
reduce_mean
(
x
,
'dim'
,
axis
,
'keep_dim'
,
keepdim
,
'reduce_all'
,
reduce_all
)
check_variable_and_dtype
(
x
,
'x/input'
,
[
'uint16'
,
'float32'
,
'float64'
],
check_variable_and_dtype
(
x
,
'x/input'
,
[
'uint16'
,
'float16'
,
'float32'
,
'float64'
],
'mean/reduce_mean'
)
check_type
(
axis
,
'axis/dim'
,
(
int
,
list
,
tuple
),
'mean/reduce_mean'
)
if
isinstance
(
axis
,
(
list
,
tuple
)):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录