PaddlePaddle / Paddle
Commit 25e070ec, authored Nov 07, 2018 by tensor-tang

    Merge remote-tracking branch 'ups/develop' into fea/jit/vadd

Parents: cb4083b9, ea8984c9

Showing 35 changed files with 800 additions and 981 deletions (+800 / -981)
Changed files:

  paddle/fluid/framework/threadpool.cc                           +18  -12
  paddle/fluid/framework/threadpool.h                            +8   -3
  paddle/fluid/operators/activation_op.cu                        +3   -1
  paddle/fluid/operators/activation_op.h                         +2   -3
  paddle/fluid/operators/batch_norm_op.cu.cc                     +12  -9
  paddle/fluid/operators/conv_cudnn_op.cu.cc                     +4   -1
  paddle/fluid/operators/cross_entropy_op.cu                     +9   -4
  paddle/fluid/operators/elementwise_add_op.cu                   +2   -1
  paddle/fluid/operators/elementwise_op_function.h               +2   -2
  paddle/fluid/operators/math/cross_entropy.cu                   +16  -6
  paddle/fluid/operators/math/cross_entropy.h                    +21  -0
  paddle/fluid/operators/math/selected_rows_functor.cu           +12  -3
  paddle/fluid/operators/math/softmax.cu                         +3   -0
  paddle/fluid/operators/mean_op.cu                              +6   -2
  paddle/fluid/operators/mean_op.h                               +1   -2
  paddle/fluid/operators/mul_op.cu.cc                            +4   -3
  paddle/fluid/operators/pool_cudnn_op.cu.cc                     +2   -1
  paddle/fluid/operators/scale_op.cu                             +5   -1
  paddle/fluid/operators/softmax_cudnn_op.cu.cc                  +2   -1
  paddle/fluid/operators/softmax_op.cu.cc                        +2   -1
  paddle/fluid/operators/sum_op.cu                               +4   -1
  paddle/fluid/operators/sum_op.h                                +1   -1
  python/paddle/fluid/io.py                                      +7   -2
  python/paddle/fluid/recordio_writer.py                         +0   -3
  python/paddle/fluid/tests/unittests/op_test.py                 +9   -8
  python/paddle/fluid/tests/unittests/test_activation_op.py      +90  -537
  python/paddle/fluid/tests/unittests/test_conv2d_op.py          +63  -88
  python/paddle/fluid/tests/unittests/test_cross_entropy_op.py   +201 -137
  python/paddle/fluid/tests/unittests/test_mean_op.py            +25  -1
  python/paddle/fluid/tests/unittests/test_mul_op.py             +78  -32
  python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py   +2   -2
  python/paddle/fluid/tests/unittests/test_pool2d_op.py          +74  -100
  python/paddle/fluid/tests/unittests/test_scale_op.py           +52  -3
  python/paddle/fluid/tests/unittests/test_softmax_op.py         +18  -6
  python/paddle/fluid/tests/unittests/test_sum_op.py             +42  -4
paddle/fluid/framework/threadpool.cc

@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) {
 ThreadPool::~ThreadPool() {
   {
     // notify all threads to stop running
-    std::lock_guard<std::mutex> l(mutex_);
+    std::unique_lock<std::mutex> l(mutex_);
     running_ = false;
-    scheduled_.notify_all();
   }
+  scheduled_.notify_all();
   for (auto& t : threads_) {
     t->join();
@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() {
 void ThreadPool::TaskLoop() {
   while (true) {
-    std::unique_lock<std::mutex> lock(mutex_);
+    Task task;
+
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
       scheduled_.wait(
           lock, [this] { return !this->tasks_.empty() || !this->running_; });

-    if (!running_ || tasks_.empty()) {
-      return;
-    }
+      if (!running_ && tasks_.empty()) {
+        return;
+      }
+
+      if (tasks_.empty()) {
+        PADDLE_THROW("This thread has no task to Run");
+      }

       // pop a task from the task queue
-    auto task = std::move(tasks_.front());
-    tasks_.pop();
+      task = std::move(tasks_.front());
+      tasks_.pop();
+
+      lock.unlock();
+    }

     // run the task
     task();
paddle/fluid/framework/threadpool.h

@@ -69,7 +69,6 @@ class ThreadPool {
   template <typename Callback>
   std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
       Callback fn) {
-    std::unique_lock<std::mutex> lock(mutex_);
     Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
       try {
         fn();
@@ -84,7 +83,13 @@ class ThreadPool {
       return nullptr;
     });
     std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      if (!running_) {
+        PADDLE_THROW("enqueue on stopped ThreadPool");
+      }
       tasks_.push(std::move(task));
+    }
     scheduled_.notify_one();
     return f;
   }
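Both threadpool changes follow the usual condition-variable discipline: mutate shared state while holding the mutex, drop the lock, then notify, and have workers re-check the predicate under the lock and drain remaining tasks before exiting. A minimal, self-contained sketch of that pattern (illustrative only; TinyQueue, Push and WorkerLoop are stand-in names, not Paddle's ThreadPool API):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>

// Illustrative single-queue worker loop, not Paddle's real class.
class TinyQueue {
 public:
  using Task = std::function<void()>;

  void Push(Task t) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      queue_.push(std::move(t));
    }                  // release the lock first...
    cv_.notify_one();  // ...then wake a worker, as RunAndGetException now does
  }

  void Stop() {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      running_ = false;
    }
    cv_.notify_all();  // mirrors the destructor change above
  }

  void WorkerLoop() {
    while (true) {
      Task task;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return !queue_.empty() || !running_; });
        if (!running_ && queue_.empty()) return;  // drain before exiting
        task = std::move(queue_.front());
        queue_.pop();
      }  // unlock before running the task so other workers can proceed
      task();
    }
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<Task> queue_;
  bool running_ = true;
};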
paddle/fluid/operators/activation_op.cu

@@ -26,6 +26,8 @@ namespace plat = paddle::platform;
       act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext,  \
                                                  ops::grad_functor<float>>, \
       ops::ActivationGradKernel<plat::CUDADeviceContext,                   \
-                                ops::grad_functor<double>>);
+                                ops::grad_functor<double>>,                \
+      ops::ActivationGradKernel<plat::CUDADeviceContext,                   \
+                                ops::grad_functor<plat::float16>>);

 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
paddle/fluid/operators/activation_op.h

@@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    const Out out_conj = Eigen::numext::conj(out);
-    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
+    dx.device(d) = static_cast<T>(0.5) * dout / out;
   }
 };
@@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     dx.device(d) = dout * static_cast<T>(factor) *
-                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
+                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
   }
 };
paddle/fluid/operators/batch_norm_op.cu.cc

@@ -219,8 +219,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

     d_x->mutable_data<T>(ctx.GetPlace());
-    d_scale->mutable_data<T>(ctx.GetPlace());
-    d_bias->mutable_data<T>(ctx.GetPlace());
+    d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     if ((N * H * W * D) == 1) {
@@ -272,8 +272,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
-    const void *saved_mean_data = saved_mean->template data<T>();
-    const void *saved_var_data = saved_var->template data<T>();
+    const void *saved_mean_data =
+        saved_mean->template data<BatchNormParamType<T>>();
+    const void *saved_var_data =
+        saved_var->template data<BatchNormParamType<T>>();

     CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
@@ -281,10 +283,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
         data_desc_, d_y->template data<T>(), data_desc_,
         d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-        scale->template data<T>(),
-        d_scale->template mutable_data<T>(ctx.GetPlace()),
-        d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
-        saved_mean_data, saved_var_data));
+        scale->template data<BatchNormParamType<T>>(),
+        d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
+        d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
+        epsilon, saved_mean_data, saved_var_data));

     // clean when exit.
     CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
@@ -304,4 +306,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
-    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
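The T to BatchNormParamType<T> substitutions reflect that cuDNN's batch-norm backward expects the scale, bias and saved mean/variance tensors in a higher-precision parameter type even when the data tensor is float16. A trait with that role could be sketched as below (an assumed shape for illustration, not the definition Paddle actually uses):

// Hypothetical sketch of a "parameter type" trait: half-precision data keeps
// float parameters, every other type keeps its own precision.
struct float16;  // stand-in for platform::float16

template <typename T>
struct BatchNormParamTypeTrait {
  using type = T;
};

template <>
struct BatchNormParamTypeTrait<float16> {
  using type = float;  // scale/bias/mean/variance stay in float32
};

template <typename T>
using BatchNormParamType = typename BatchNormParamTypeTrait<T>::type;

static_assert(sizeof(BatchNormParamType<float16>) == sizeof(float),
              "half data uses float parameters");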
paddle/fluid/operators/conv_cudnn_op.cu.cc

@@ -143,9 +143,11 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
       // Currently tensor core is only enabled using this algo
       algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+      VLOG(5) << "use cudnn_tensor_op_math";
     } else {
       CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
           cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
     }
 #endif
@@ -361,7 +363,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
-                   paddle::operators::CUDNNConvGradOpKernel<double>);
+                   paddle::operators::CUDNNConvGradOpKernel<double>,
+                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
paddle/fluid/operators/cross_entropy_op.cu

@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include "paddle/fluid/platform/float16.h"

+namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 using CUDACtx = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(cross_entropy,
                         ops::CrossEntropyOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyOpKernel<CUDACtx, double>);
-REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, double>);
+                        ops::CrossEntropyOpKernel<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
paddle/fluid/operators/elementwise_add_op.cu

@@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
paddle/fluid/operators/elementwise_op_function.h

@@ -365,7 +365,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  T val = 0;
+  T val(0);

   do {
     int x_offset = i * w + j;
@@ -433,7 +433,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   int tid = threadIdx.x;
   int j = blockIdx.x;
-  T val = 0;
+  T val(0);
   int ttid = tid;

   while (true) {
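Replacing T val = 0 with T val(0) matters once T can be platform::float16: copy-initialization from an int literal needs an implicit conversion, while direct-initialization may also use an explicit constructor, which is all a half-precision wrapper type might offer. A minimal sketch with a hypothetical Half wrapper:

// Hypothetical half-like wrapper: the int constructor is explicit, so
// copy-initialization from a literal would not compile.
struct Half {
  unsigned short bits = 0;
  Half() = default;
  explicit Half(int) {}  // conversion details omitted; only explicitness matters
};

template <typename T>
T ZeroDirect() {
  T val(0);  // direct-initialization: explicit constructors are eligible
  return val;
}

// template <typename T> T ZeroCopy() { T val = 0; return val; }
// would fail to instantiate for Half, because `= 0` is copy-initialization.

int main() {
  float f = ZeroDirect<float>();  // works for built-ins...
  Half h = ZeroDirect<Half>();    // ...and for the wrapper type
  (void)f;
  (void)h;
  return 0;
}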
paddle/fluid/operators/math/cross_entropy.cu

@@ -21,6 +21,16 @@ namespace operators {
 namespace math {

 namespace {
+
+__device__ __forceinline__ float real_log(float x) { return logf(x); }
+
+__device__ __forceinline__ double real_log(double x) { return log(x); }
+
+__device__ __forceinline__ platform::float16 real_log(
+    const platform::float16& val) {
+  return static_cast<platform::float16>(hlog(static_cast<half>(val)));
+}
+
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                    const int N, const int D,
@@ -29,8 +39,8 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D || label[i] == ignore_index);
     Y[i] = ignore_index == label[i]
-               ? 0
-               : -math::TolerableValue<T>()(log(X[i * D + label[i]]));
+               ? static_cast<T>(0)
+               : -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
   }
 }
@@ -38,12 +48,12 @@ template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                        const int class_num) {
   int tid = threadIdx.x;
-  T val = 0;
+  T val(0);

   int idx = blockIdx.x * class_num + tid;
   int end = blockIdx.x * class_num + class_num;
   for (; idx < end; idx += blockDim.x) {
-    val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
+    val += math::TolerableValue<T>()(real_log(X[idx])) * label[idx];
   }

   val = paddle::platform::reduceSum(val, tid, blockDim.x);
@@ -53,8 +63,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
 }
 }  // namespace

-using Tensor = framework::Tensor;
-
 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -89,6 +97,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
 template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
+template class CrossEntropyFunctor<platform::CUDADeviceContext,
+                                   platform::float16>;
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
paddle/fluid/operators/math/cross_entropy.h

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <limits>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/hostdevice.h"
@@ -33,6 +34,26 @@ struct TolerableValue {
   }
 };

+// NOTE(dzh): float16 value clip behave different.
+// 1. Our ValueClipping has a hardcore threshold 1e20
+// for float number. 1e20 will resulting in overflow in float16.
+// 2. float16 should expose the the real number overflow to python.
+// because mixed-training depends the inf/nan value to determine
+// if the scale value will be adjusted.
+// Also. In standard implementation of cross entropy, other
+// framework not has the ValueClipping.
+template <>
+struct TolerableValue<platform::float16> {
+  HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
+    if (platform::isfinite(x))
+      return x;
+    else if (x > static_cast<platform::float16>(0))
+      return std::numeric_limits<platform::float16>::max();
+    else
+      return std::numeric_limits<platform::float16>::min();
+  }
+};
+
 template <typename DeviceContext, typename T>
 class CrossEntropyFunctor {
  public:
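The NOTE in this hunk carries the reasoning: for float16 the functor saturates non-finite values to the type's own limits instead of clipping at 1e20, so overflow remains visible to the loss-scaling logic used in mixed-precision training. A rough host-side illustration of that rule, using float as a stand-in for platform::float16:

#include <cmath>
#include <cstdio>
#include <limits>

// Stand-in for the float16 specialization's behaviour: finite values pass
// through; non-finite positive values saturate to max(), everything else to
// min(), mirroring the specialization's use of numeric_limits::min().
float TolerableHalf(float x) {
  if (std::isfinite(x)) return x;
  if (x > 0.0f) return std::numeric_limits<float>::max();
  return std::numeric_limits<float>::min();
}

int main() {
  std::printf("%g\n", TolerableHalf(3.5f));        // finite: unchanged
  std::printf("%g\n", TolerableHalf(HUGE_VALF));   // +inf: saturates to max()
  std::printf("%g\n", TolerableHalf(-HUGE_VALF));  // -inf: saturates to min()
  return 0;
}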
paddle/fluid/operators/math/selected_rows_functor.cu

@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
@@ -118,7 +119,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     auto* out_data = output->data<T>();

     SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(context, output, 0.0);
+    functor(context, output, static_cast<T>(0));

     const int block_size = 256;
     dim3 threads(block_size, 1);
@@ -136,6 +137,9 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
 template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
+template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
+template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
+                                      platform::float16>;

 template <typename T>
 struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
@@ -175,6 +179,8 @@ template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
+template struct SelectedRowsAddTo<platform::CUDADeviceContext,
+                                  platform::float16>;

 namespace {
 template <typename T, int block_size>
@@ -227,6 +233,8 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
+template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
+                                        platform::float16>;

 namespace scatter {
@@ -287,7 +295,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
                                   context.GetPlace());

     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
@@ -347,7 +355,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
                                   context.GetPlace());

     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();
@@ -374,6 +382,7 @@ template struct MergeAdd<platform::CUDADeviceContext, float>;
 template struct MergeAdd<platform::CUDADeviceContext, double>;
 template struct MergeAdd<platform::CUDADeviceContext, int>;
 template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
+template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;

 template <typename T, int block_size>
 __global__ void UpdateToTensorKernel(const T* selected_rows,
paddle/fluid/operators/math/softmax.cu

@@ -96,12 +96,15 @@ template class SoftmaxCUDNNFunctor<float>;
 template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
 template class SoftmaxGradCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<platform::float16>;

 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
+template class SoftmaxGradFunctor<platform::CUDADeviceContext,
+                                  platform::float16>;

 }  // namespace math
 }  // namespace operators
paddle/fluid/operators/mean_op.cu

@@ -15,11 +15,15 @@ limitations under the License. */
 #define EIGEN_USE_GPU

 #include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     mean_grad, ops::MeanGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
paddle/fluid/operators/mean_op.h

@@ -55,8 +55,7 @@ class MeanGradKernel : public framework::OpKernel<T> {
     IG->mutable_data<T>(context.GetPlace());

     T ig_size = static_cast<T>(IG->numel());
-    Eigen::DSizes<int, 1> bcast(ig_size);
+    Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
     EigenVector<T>::Flatten(*IG).device(
         *context.template device_context<DeviceContext>().eigen_device()) =
         (EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
paddle/fluid/operators/mul_op.cu.cc

@@ -20,6 +20,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
                         ops::MulKernel<plat::CUDADeviceContext, double>,
                         ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(mul_grad,
-                        ops::MulGradKernel<plat::CUDADeviceContext, float>,
-                        ops::MulGradKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
+    ops::MulGradKernel<plat::CUDADeviceContext, double>,
+    ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
paddle/fluid/operators/pool_cudnn_op.cu.cc

@@ -178,7 +178,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNGradOpKernel<float>,
-                   ops::PoolCUDNNGradOpKernel<double>);
+                   ops::PoolCUDNNGradOpKernel<double>,
+                   ops::PoolCUDNNGradOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<float>,
paddle/fluid/operators/scale_op.cu

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/scale_op.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     scale,
@@ -20,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL(
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, double>,
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, int>,
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
-                                   int64_t>);
+                                   int64_t>,
+    paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
+                                   plat::float16>);
paddle/fluid/operators/softmax_cudnn_op.cu.cc

@@ -80,4 +80,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxCUDNNKernel<plat::float16>);
 REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxGradCUDNNKernel<float>,
-                   ops::SoftmaxGradCUDNNKernel<double>);
+                   ops::SoftmaxGradCUDNNKernel<double>,
+                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
paddle/fluid/operators/softmax_op.cu.cc

@@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
-    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>,
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>);
paddle/fluid/operators/sum_op.cu

@@ -11,10 +11,13 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/sum_op.h"
+#include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     sum, ops::SumKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
paddle/fluid/operators/sum_op.h

@@ -61,7 +61,7 @@ class SumKernel : public framework::OpKernel<T> {
       if (start != 2) {
         math::SetConstant<DeviceContext, T> constant_functor;
         constant_functor(context.template device_context<DeviceContext>(),
-                         out, 0.0);
+                         out, static_cast<T>(0));
       }
     }
python/paddle/fluid/io.py

@@ -65,7 +65,7 @@ def is_persistable(var):
     Examples:
         .. code-block:: python

-            param = fluid.default_main_program().global_block().var('fc.w')
+            param = fluid.default_main_program().global_block().var('fc.b')
             res = fluid.io.is_persistable(param)
     """
     if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
@@ -625,8 +625,13 @@ def save_inference_model(dirname,
             main_program._distributed_lookup_table,
             main_program._endpoints)

-    if not os.path.isdir(dirname):
+    # when a pserver and a trainer running on the same machine, mkdir may conflict
+    try:
         os.makedirs(dirname)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise

     if model_filename is not None:
         model_basename = os.path.basename(model_filename)
     else:
...
python/paddle/fluid/recordio_writer.py
浏览文件 @
25e070ec
...
@@ -41,9 +41,6 @@ def convert_reader_to_recordio_file(
...
@@ -41,9 +41,6 @@ def convert_reader_to_recordio_file(
"""
"""
Convert a Python Reader to a recordio file.
Convert a Python Reader to a recordio file.
Please see :ref:`api_guide_python_reader` and :ref:`api_guide_reader_op` for
details.
Examples:
Examples:
>>> import paddle.fluid as fluid
>>> import paddle.fluid as fluid
...
...
python/paddle/fluid/tests/unittests/op_test.py

@@ -54,14 +54,6 @@ def get_numeric_gradient(place,
     def product(dim):
         return six.moves.reduce(lambda a, b: a * b, dim, 1)

-    def get_output():
-        sum = []
-        op.run(scope, place)
-        for output_name in output_names:
-            sum.append(
-                np.array(scope.find_var(output_name).get_tensor()).mean())
-        return np.array(sum).sum() / len(output_names)
-
     tensor_to_check = scope.find_var(input_to_check).get_tensor()
     tensor_size = product(tensor_to_check.shape())
     tensor_to_check_dtype = tensor_to_check._dtype()
@@ -77,6 +69,15 @@ def get_numeric_gradient(place,
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))

+    def get_output():
+        sum = []
+        op.run(scope, place)
+        for output_name in output_names:
+            sum.append(
+                np.array(scope.find_var(output_name).get_tensor()).astype(
+                    tensor_to_check_dtype).mean())
+        return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))
+
     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)

     def __get_elem__(tensor, i):
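The relocated get_output now averages the op's outputs in tensor_to_check_dtype, so the finite-difference quotient is formed in the precision actually being checked. A stripped-down sketch of the same idea, a central-difference gradient of a NumPy function accumulated in a chosen dtype (not the real OpTest machinery):

import numpy as np

def numeric_gradient(f, x, dtype, delta=0.005):
    """Central-difference gradient of scalar-valued f at x, kept in `dtype`.

    One element of x is perturbed at a time, mirroring how
    get_numeric_gradient walks the tensor being checked.
    """
    x = x.astype(dtype)
    grad = np.zeros_like(x, dtype=dtype)
    flat_x, flat_g = x.ravel(), grad.ravel()
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + dtype(delta)
        y_pos = dtype(f(x))
        flat_x[i] = orig - dtype(delta)
        y_neg = dtype(f(x))
        flat_x[i] = orig
        flat_g[i] = (y_pos - y_neg) / dtype(2 * delta)
    return grad

# Example: numeric gradient of mean(x ** 2), accumulated in float16.
x = np.random.uniform(-1, 1, [3, 4]).astype(np.float16)
g = numeric_gradient(lambda t: (t ** 2).mean(), x, np.float16)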
python/paddle/fluid/tests/unittests/test_activation_op.py

@@ -21,7 +21,7 @@ from op_test import OpTest
 from scipy.special import expit


-class TestExp(OpTest):
+class TestActivation(OpTest):
     def setUp(self):
         self.op_type = "exp"
         self.dtype = np.float32
@@ -42,24 +42,12 @@ class TestExp(OpTest):
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

     def init_dtype(self):
-        pass
+        self.dtype = np.float32

-
-class TestFP16Exp(TestExp):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-

-class TestSigmoid(OpTest):
+class TestSigmoid(TestActivation):
     def setUp(self):
         self.op_type = "sigmoid"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
@@ -68,33 +56,15 @@ class TestSigmoid(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.01)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Sigmoid(TestSigmoid):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestLogSigmoid(OpTest):
+class TestLogSigmoid(TestActivation):
     def setUp(self):
         self.op_type = "logsigmoid"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
@@ -103,33 +73,15 @@ class TestLogSigmoid(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.008)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16LogSigmoid(TestLogSigmoid):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestTanh(OpTest):
+class TestTanh(TestActivation):
     def setUp(self):
         self.op_type = "tanh"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
@@ -138,33 +90,15 @@ class TestTanh(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Tanh(TestTanh):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestTanhShrink(OpTest):
+class TestTanhShrink(TestActivation):
     def setUp(self):
         self.op_type = "tanh_shrink"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(0.1, 1, [10, 17]).astype(self.dtype)
@@ -173,33 +107,15 @@ class TestTanhShrink(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.008)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16TanhShrink(TestTanhShrink):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestHardShrink(OpTest):
+class TestHardShrink(TestActivation):
     def setUp(self):
         self.op_type = "hard_shrink"
-        self.dtype = np.float32
         self.init_dtype()

         threshold = 0.5
@@ -211,33 +127,15 @@ class TestHardShrink(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.005)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16HardShrink(TestHardShrink):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestSoftShrink(OpTest):
+class TestSoftShrink(TestActivation):
     def setUp(self):
         self.op_type = "softshrink"
-        self.dtype = np.float32
         self.init_dtype()

         lambda_val = 0.1
@@ -250,33 +148,15 @@ class TestSoftShrink(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16SoftShrink(TestSoftShrink):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestSqrt(OpTest):
+class TestSqrt(TestActivation):
     def setUp(self):
         self.op_type = "sqrt"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
@@ -285,33 +165,15 @@ class TestSqrt(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Sqrt(TestSqrt):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestAbs(OpTest):
+class TestAbs(TestActivation):
     def setUp(self):
         self.op_type = "abs"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -325,33 +187,15 @@ class TestAbs(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Abs(TestAbs):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestCeil(OpTest):
+class TestCeil(TestActivation):
     def setUp(self):
         self.op_type = "ceil"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -360,30 +204,14 @@ class TestCeil(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     # The same reason with TestFloor
-    def init_dtype(self):
+    def test_check_grad(self):
         pass

-
-class TestFP16Ceil(TestCeil):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestFloor(OpTest):
+class TestFloor(TestActivation):
     def setUp(self):
         self.op_type = "floor"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -392,31 +220,16 @@ class TestFloor(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     # the gradient on floor, ceil, round is undefined.
     # we return zero as gradient, but the numpy return nan
-    # The same reason with TestFloor
-    def init_dtype(self):
+    def test_check_grad(self):
         pass

-
-class TestFP16Floor(TestFloor):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestCos(OpTest):
+class TestCos(TestActivation):
     def setUp(self):
         self.op_type = "cos"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -425,33 +238,15 @@ class TestCos(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Cos(TestCos):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestSin(OpTest):
+class TestSin(TestActivation):
     def setUp(self):
         self.op_type = "sin"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -460,33 +255,15 @@ class TestSin(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out', max_relative_error=0.007)

-    def init_dtype(self):
-        pass
-
-
-class TestFP16Sin(TestSin):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestRound(OpTest):
+class TestRound(TestActivation):
     def setUp(self):
         self.op_type = "round"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype)
@@ -495,28 +272,13 @@ class TestRound(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}

-    def test_check_output(self):
-        self.check_output()
-
-    def init_dtype(self):
+    def test_check_grad(self):
         pass

-
-class TestFP16Round(TestRound):
-    def init_dtype(self):
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-class TestRelu(OpTest):
+class TestRelu(TestActivation):
     def setUp(self):
         self.op_type = "relu"
-        self.dtype = np.float32
         self.init_dtype()

         x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
.
dtype
)
...
@@ -527,33 +289,15 @@ class TestRelu(OpTest):
...
@@ -527,33 +289,15 @@ class TestRelu(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestFP16Relu
(
TestRelu
):
class
TestBRelu
(
TestActivation
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestBRelu
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"brelu"
self
.
op_type
=
"brelu"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
4
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
4
]).
astype
(
self
.
dtype
)
...
@@ -570,33 +314,15 @@ class TestBRelu(OpTest):
...
@@ -570,33 +314,15 @@ class TestBRelu(OpTest):
self
.
attrs
=
{
't_min'
:
t_min
,
't_max'
:
t_max
}
self
.
attrs
=
{
't_min'
:
t_min
,
't_max'
:
t_max
}
self
.
outputs
=
{
'Out'
:
t
}
self
.
outputs
=
{
'Out'
:
t
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
def
init_dtype
(
self
):
pass
class
TestFP16BRelu
(
TestBRelu
):
class
TestRelu6
(
TestActivation
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestRelu6
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"relu6"
self
.
op_type
=
"relu6"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
10
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
4
,
10
]).
astype
(
self
.
dtype
)
...
@@ -610,33 +336,15 @@ class TestRelu6(OpTest):
...
@@ -610,33 +336,15 @@ class TestRelu6(OpTest):
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
def
init_dtype
(
self
):
pass
class
TestFP16Relu6
(
TestRelu6
):
class
TestSoftRelu
(
TestActivation
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestSoftRelu
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"soft_relu"
self
.
op_type
=
"soft_relu"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
3
,
3
,
[
4
,
4
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
3
,
3
,
[
4
,
4
]).
astype
(
self
.
dtype
)
...
@@ -653,33 +361,15 @@ class TestSoftRelu(OpTest):
...
@@ -653,33 +361,15 @@ class TestSoftRelu(OpTest):
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
def
init_dtype
(
self
):
pass
class
TestFP16SoftRelu
(
TestSoftRelu
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
class
TestELU
(
TestActivation
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestELU
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"elu"
self
.
op_type
=
"elu"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
3
,
3
,
[
4
,
4
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
3
,
3
,
[
4
,
4
]).
astype
(
self
.
dtype
)
...
@@ -691,33 +381,15 @@ class TestELU(OpTest):
...
@@ -691,33 +381,15 @@ class TestELU(OpTest):
self
.
attrs
=
{
'alpha'
:
alpha
}
self
.
attrs
=
{
'alpha'
:
alpha
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
def
init_dtype
(
self
):
pass
class
TestReciprocal
(
TestActivation
):
class
TestFP16ELU
(
TestELU
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestReciprocal
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"reciprocal"
self
.
op_type
=
"reciprocal"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -726,33 +398,15 @@ class TestReciprocal(OpTest):
...
@@ -726,33 +398,15 @@ class TestReciprocal(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.01
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.01
)
def
init_dtype
(
self
):
pass
class
TestFP16Reciprocal
(
TestReciprocal
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestLog
(
TestActivation
):
class
TestLog
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"log"
self
.
op_type
=
"log"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -761,33 +415,15 @@ class TestLog(OpTest):
...
@@ -761,33 +415,15 @@ class TestLog(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestFP16Log
(
TestLog
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestSquare
(
OpTest
):
class
TestSquare
(
TestActivation
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"square"
self
.
op_type
=
"square"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -796,33 +432,15 @@ class TestSquare(OpTest):
...
@@ -796,33 +432,15 @@ class TestSquare(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestFP16Square
(
TestSquare
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestPow
(
TestActivation
):
class
TestPow
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"pow"
self
.
op_type
=
"pow"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
1
,
2
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -832,33 +450,15 @@ class TestPow(OpTest):
...
@@ -832,33 +450,15 @@ class TestPow(OpTest):
self
.
attrs
=
{
'factor'
:
3.0
}
self
.
attrs
=
{
'factor'
:
3.0
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.02
)
def
init_dtype
(
self
):
pass
class
TestFP16Pow
(
TestPow
):
class
TestSTanh
(
TestActivation
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
5e-2
)
class
TestSTanh
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"stanh"
self
.
op_type
=
"stanh"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -870,34 +470,17 @@ class TestSTanh(OpTest):
...
@@ -870,34 +470,17 @@ class TestSTanh(OpTest):
self
.
attrs
=
{
'scale_a'
:
scale_a
,
'scale_b'
:
scale_b
}
self
.
attrs
=
{
'scale_a'
:
scale_a
,
'scale_b'
:
scale_b
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestFP16STanh
(
TestSTanh
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestSoftplus
(
TestActivation
):
class
TestSoftplus
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"softplus"
self
.
op_type
=
"softplus"
self
.
dtype
=
np
.
float64
self
.
init_dtype
()
self
.
init_dtype
()
self
.
dtype
=
np
.
float64
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
out
=
np
.
log
(
1
+
np
.
exp
(
x
))
out
=
np
.
log
(
1
+
np
.
exp
(
x
))
...
@@ -905,33 +488,15 @@ class TestSoftplus(OpTest):
...
@@ -905,33 +488,15 @@ class TestSoftplus(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestFP16Softplus
(
TestSoftplus
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
class
TestSoftsign
(
TestActivation
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestSoftsign
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"softsign"
self
.
op_type
=
"softsign"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -940,33 +505,15 @@ class TestSoftsign(OpTest):
...
@@ -940,33 +505,15 @@ class TestSoftsign(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.007
)
def
init_dtype
(
self
):
pass
class
TestThresholdedRelu
(
TestActivation
):
class
TestFP16Softsign
(
TestSoftsign
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestThresholdedRelu
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"thresholded_relu"
self
.
op_type
=
"thresholded_relu"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
threshold
=
0.25
threshold
=
0.25
...
@@ -981,33 +528,15 @@ class TestThresholdedRelu(OpTest):
...
@@ -981,33 +528,15 @@ class TestThresholdedRelu(OpTest):
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
self
.
relative_error
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
self
.
relative_error
)
def
init_dtype
(
self
):
pass
class
TestFP16ThresholdedRelu
(
TestThresholdedRelu
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestHardSigmoid
(
TestActivation
):
class
TestHardSigmoid
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"hard_sigmoid"
self
.
op_type
=
"hard_sigmoid"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
self
.
relative_error
=
0.002
self
.
relative_error
=
0.002
...
@@ -1030,33 +559,15 @@ class TestHardSigmoid(OpTest):
...
@@ -1030,33 +559,15 @@ class TestHardSigmoid(OpTest):
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
X
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
X
)}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.002
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.002
)
def
init_dtype
(
self
):
pass
class
TestFP16HardSigmoid
(
TestHardSigmoid
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
class
TestSwish
(
OpTest
):
class
TestSwish
(
TestActivation
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"swish"
self
.
op_type
=
"swish"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
init_dtype
()
X
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
X
=
np
.
random
.
uniform
(
0.1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
...
@@ -1067,28 +578,70 @@ class TestSwish(OpTest):
...
@@ -1067,28 +578,70 @@ class TestSwish(OpTest):
self
.
attrs
=
{
'beta'
:
beta
}
self
.
attrs
=
{
'beta'
:
beta
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
if
self
.
dtype
==
np
.
float16
:
return
return
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.008
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.008
)
def
init_dtype
(
self
):
pass
class
TestFP16Swish
(
TestSwish
):
#------------------ Test Fp16 ----------------------
def
create_test_act_fp16_class
(
parent
,
atol
=
1e-3
,
grad_check
=
True
,
grad_atol
=
0.80
):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestActFp16
(
parent
):
def
init_dtype
(
self
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
place
=
core
.
CUDAPlace
(
0
)
if
core
.
is_float16_supported
(
place
):
support_fp16
=
core
.
is_float16_supported
(
place
)
self
.
check_output_with_place
(
place
,
atol
=
1e-3
)
if
support_fp16
:
self
.
check_output_with_place
(
place
,
atol
=
atol
)
def
test_check_grad
(
self
):
place
=
core
.
CUDAPlace
(
0
)
support_fp16
=
core
.
is_float16_supported
(
place
)
if
support_fp16
and
grad_check
:
self
.
check_grad_with_place
(
place
,
[
'X'
],
'Out'
,
max_relative_error
=
grad_atol
)
cls_name
=
"{0}_{1}"
.
format
(
parent
.
__name__
,
"fp16"
)
TestActFp16
.
__name__
=
cls_name
globals
()[
cls_name
]
=
TestActFp16
create_test_act_fp16_class
(
TestActivation
)
create_test_act_fp16_class
(
TestSigmoid
)
create_test_act_fp16_class
(
TestLogSigmoid
)
create_test_act_fp16_class
(
TestTanh
)
create_test_act_fp16_class
(
TestTanhShrink
)
create_test_act_fp16_class
(
TestHardShrink
)
create_test_act_fp16_class
(
TestSoftShrink
)
create_test_act_fp16_class
(
TestSqrt
)
create_test_act_fp16_class
(
TestAbs
)
create_test_act_fp16_class
(
TestCeil
,
grad_check
=
False
)
create_test_act_fp16_class
(
TestFloor
,
grad_check
=
False
)
create_test_act_fp16_class
(
TestCos
,
grad_atol
=
0.85
)
create_test_act_fp16_class
(
TestSin
)
create_test_act_fp16_class
(
TestRound
,
grad_check
=
False
)
create_test_act_fp16_class
(
TestRelu
)
create_test_act_fp16_class
(
TestBRelu
)
create_test_act_fp16_class
(
TestRelu6
)
create_test_act_fp16_class
(
TestSoftRelu
)
create_test_act_fp16_class
(
TestELU
)
create_test_act_fp16_class
(
TestReciprocal
)
create_test_act_fp16_class
(
TestLog
)
create_test_act_fp16_class
(
TestSquare
)
create_test_act_fp16_class
(
TestPow
,
atol
=
5e-2
)
create_test_act_fp16_class
(
TestSTanh
,
grad_atol
=
0.9
)
create_test_act_fp16_class
(
TestSoftplus
)
create_test_act_fp16_class
(
TestSoftsign
)
create_test_act_fp16_class
(
TestThresholdedRelu
)
create_test_act_fp16_class
(
TestHardSigmoid
)
create_test_act_fp16_class
(
TestSwish
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
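The factory above registers each generated class by assigning it into globals(), which is what lets unittest's loader pick it up as if it had been written by hand. The following is a minimal standalone sketch of that pattern, outside of Paddle; the TestAddition base class, the tol knob, and the "loose" suffix are made up for illustration and are not part of the PR.

    import unittest


    class TestAddition(unittest.TestCase):
        tol = 1e-9                      # knob the generated variants override

        def test_add(self):
            self.assertAlmostEqual(1.0 + 2.0, 3.0, delta=self.tol)


    def create_variant(parent, suffix, tol):
        # Subclass `parent`, change one knob, then register the class under a new
        # module-level name so the unittest loader discovers it automatically.
        class Variant(parent):
            pass

        Variant.tol = tol
        cls_name = "{0}_{1}".format(parent.__name__, suffix)
        Variant.__name__ = cls_name
        globals()[cls_name] = Variant


    create_variant(TestAddition, "loose", tol=1e-3)

    if __name__ == "__main__":
        unittest.main()

Renaming the class before registering it matters: without a unique __name__ and module-level binding, every generated variant would shadow the previous one and only a single copy would run.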
python/paddle/fluid/tests/unittests/test_conv2d_op.py
View file @ 25e070ec
...
@@ -223,46 +223,34 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):

 #----------------Conv2dCUDNN----------------


-class TestCUDNN(TestConv2dOp):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNN(TestConv2dOp):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWithPad(TestWithPad): ...
-class TestFP16CUDNNWithPad(TestWithPad): ...
-class TestCUDNNWithStride(TestWithStride): ...
-class TestFP16CUDNNWithStride(TestWithStride): ...
-class TestCUDNNWithGroup(TestWithGroup): ...
-class TestFP16CUDNNWithGroup(TestWithGroup): ...
-class TestCUDNNWith1x1(TestWith1x1): ...
-class TestFP16CUDNNWith1x1(TestWith1x1): ...
-class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): ...
-class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1): ...
+def create_test_cudnn_class(parent, cls_name):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+    cls_name = "{0}".format(cls_name)
+    TestCUDNNCase.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNCase
+
+
+create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp")
+create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1")
+create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2")
+create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3")
+create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4")
+create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4")
+
+
+#----------------Conv2dCUDNN----------------
+def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestConv2DCUDNNFp16(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.float16
+
...
@@ -273,56 +261,43 @@ class TestFP16CUDNNWithStride(TestWithStride):
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
                 if core.is_float16_supported(place):
                     self.check_output_with_place(place, atol=2e-2)
+
+        def test_check_grad_no_filter(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Input'], 'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Filter']))
+
+        def test_check_grad_no_input(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Filter'], 'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Input']))
+
+    cls_name = "{0}".format(cls_name)
+    TestConv2DCUDNNFp16.__name__ = cls_name
+    globals()[cls_name] = TestConv2DCUDNNFp16
+
+
+create_test_cudnn_fp16_class(
+    TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)


 # -------TestDepthwiseConv

 class TestDepthwiseConv(TestConv2dOp):
...
python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
View file @ 25e070ec
...
@@ -16,28 +16,58 @@ from __future__ import print_function

 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest, randomize_probability


-class TestCrossEntropyOp1(OpTest):
+class TestCrossEntropyOp(OpTest):
     """Test cross-entropy with discrete one-hot labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 30
-        class_num = 10
-
-        X = randomize_probability(batch_size, class_num, dtype='float64')
-
-        label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
-        cross_entropy = np.asmatrix(
-            [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
-            dtype="float64")
-
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": False}
+        self.soft_label = False
+        self.ignore_index = -100
+        self.dtype = np.float64
+        self.batch_size = 30
+        self.class_num = 10
+
+        self.init_dtype_type()
+        self.init_attr_type()
+        self.init_bs_class_num()
+        self.init_x()
+        self.init_label()
+        self.get_cross_entropy()
+
+        self.inputs = {"X": self.x, "Label": self.label}
+        self.outputs = {"Y": self.cross_entropy}
+        self.attrs = {
+            "soft_label": self.soft_label,
+            "ignore_index": self.ignore_index
+        }
+
+    def init_x(self):
+        self.x = randomize_probability(
+            self.batch_size, self.class_num, dtype=self.dtype)
+
+    def init_label(self):
+        self.label = np.random.randint(
+            0, self.class_num, (self.batch_size, 1), dtype="int64")
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [[-np.log(self.x[i][self.label[i][0]])]
+             for i in range(self.x.shape[0])],
+            dtype="float64")
+
+    def init_attr_type(self):
+        pass
+
+    def init_dtype_type(self):
+        pass
+
+    def init_bs_class_num(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -46,197 +76,231 @@ class TestCrossEntropyOp1(OpTest):
         self.check_grad(["X"], "Y", numeric_grad_delta=0.001)


-class TestCrossEntropyOp2(OpTest):
+class TestCrossEntropyOp2(TestCrossEntropyOp):
     """Test cross-entropy with vectorized soft labels.
     """

-    def setUp(self):
-        ...
+    def init_label(self):
+        self.label = np.random.uniform(
+            0.1, 1.0, [self.batch_size, self.class_num]).astype(self.dtype)
+        self.label /= self.label.sum(axis=1, keepdims=True)
+
+    def get_cross_entropy(self):
+        self.cross_entropy = (-self.label * np.log(self.x)).sum(
+            axis=1, keepdims=True).astype(self.dtype)
+
+    def init_attr_type(self):
+        self.soft_label = True
+
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def init_bs_class_num(self):
+        self.batch_size = 5
+        self.class_num = 37

     def test_check_grad(self):
         self.check_grad(
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)

The remaining cases are reworked the same way, replacing their inline setUp with overridden hooks:

-class TestCrossEntropyOp3(OpTest):
+class TestCrossEntropyOp3(TestCrossEntropyOp):
     """Test cross-entropy with vectorized one-hot representation of labels.
     """
     (init_label builds a one-hot label from a random label_index; get_cross_entropy indexes x with
      label_index; soft_label = True, dtype np.float32, batch_size = 5, class_num = 17;
      test_check_grad uses max_relative_error=0.05, numeric_grad_delta=0.001.)

-class TestCrossEntropyOp4(OpTest):
+class TestCrossEntropyOp4(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with discrete one-hot labels.
     """
     (init_x builds X_2d = randomize_probability(ins_num, class_num) for shape [10, 2, 4] and reshapes
      it to x; init_label and get_cross_entropy reshape label_2d and the 2-D cross entropy back to the
      high-rank shape; soft_label = False, dtype np.float64, class_num = 10.)

-class TestCrossEntropyOp5(OpTest):
+class TestCrossEntropyOp5(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with vectorized soft labels.
     """
     (shape [4, 3] with soft labels normalized over axis 1; soft_label = True, dtype np.float32,
      class_num = 37; test_check_grad uses max_relative_error=0.05, numeric_grad_delta=0.001.)

-class TestCrossEntropyOp6(OpTest):
+class TestCrossEntropyOp6(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with vectorized one-hot representation of labels.
     """
     (shape [4, 3, 2] with a one-hot label built from label_index_2d; soft_label = True,
      dtype np.float32, class_num = 17; test_check_grad as above.)

-class TestCrossEntropyOp7(OpTest):
+class TestCrossEntropyOp7(TestCrossEntropyOp):
     """Test cross-entropy with ignore index.
     """
     (get_cross_entropy emits [0] for rows whose label equals ignore_index; soft_label = False,
      ignore_index = 3, dtype np.float64, batch_size = 30, class_num = 10.)


+# Add Fp16 test
+def create_test_class(parent, cls_name):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCrossEntropyFP16Op(parent):
+        def init_dtype_type(self):
+            return np.float16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-1)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ['X'], 'Y', max_relative_error=0.9)
+
+    cls_name = "{0}".format(cls_name)
+    TestCrossEntropyFP16Op.__name__ = cls_name
+    globals()[cls_name] = TestCrossEntropyFP16Op
+
+
+create_test_class(TestCrossEntropyOp, "TestCrossEntropyF16Op")
+#create_test_class(TestCrossEntropyOp2, "TestCrossEntropyF16Op2")
+create_test_class(TestCrossEntropyOp3, "TestCrossEntropyF16Op3")
+create_test_class(TestCrossEntropyOp4, "TestCrossEntropyF16Op4")
+#create_test_class(TestCrossEntropyOp5, "TestCrossEntropyF16Op5")
+create_test_class(TestCrossEntropyOp6, "TestCrossEntropyF16Op6")
+create_test_class(TestCrossEntropyOp7, "TestCrossEntropyF16Op7")

 if __name__ == "__main__":
     unittest.main()
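The cross-entropy refactor is a template-method layout: setUp() fixes defaults and then calls a sequence of init_* hooks, so each subclass only overrides the aspect it varies. The sketch below shows that shape in isolation; BaseCase and its hook names are illustrative stand-ins, not the Paddle OpTest API.

    import unittest
    import numpy as np


    class BaseCase(unittest.TestCase):
        def setUp(self):
            # defaults first, then hooks that subclasses may override
            self.dtype = np.float64
            self.batch_size = 4
            self.init_dtype_type()
            self.init_bs()
            self.data = np.zeros((self.batch_size,), dtype=self.dtype)

        def init_dtype_type(self):
            pass

        def init_bs(self):
            pass

        def test_shape_and_dtype(self):
            self.assertEqual(self.data.shape, (self.batch_size,))
            self.assertEqual(self.data.dtype, np.dtype(self.dtype))


    class Float32Case(BaseCase):
        def init_dtype_type(self):
            self.dtype = np.float32


    class BigBatchCase(BaseCase):
        def init_bs(self):
            self.batch_size = 32


    if __name__ == "__main__":
        unittest.main()

Because the FP16 variants only need to override one hook, the create_test_class factory above can wrap any of the cross-entropy cases without copying their label or output construction.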
python/paddle/fluid/tests/unittests/test_mean_op.py
View file @ 25e070ec
...
@@ -17,14 +17,20 @@ from __future__ import print_function

 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core


 class TestMeanOp(OpTest):
     def setUp(self):
         self.op_type = "mean"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.dtype = np.float32
+        self.init_dtype_type()
+        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
         self.outputs = {'Out': np.mean(self.inputs["X"])}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()

@@ -32,5 +38,23 @@ class TestMeanOp(OpTest):
         self.check_grad(['X'], 'Out')


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MeanOp(TestMeanOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=2e-3)
+
+    def test_checkout_grad(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'], 'Out', max_relative_error=0.8)
+
+
 if __name__ == "__main__":
     unittest.main()
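The FP16 cases above use a two-level guard: a class-level unittest.skipIf removes the whole class when the wheel was not compiled with CUDA, and a runtime check inside each test skips devices that cannot run float16. A rough standalone sketch of that split, where HAS_CUDA and fp16_supported() are stand-ins for core.is_compiled_with_cuda() and core.is_float16_supported(place), not real Paddle calls:

    import unittest

    HAS_CUDA = False          # stand-in for core.is_compiled_with_cuda()


    def fp16_supported():
        # stand-in for core.is_float16_supported(core.CUDAPlace(0))
        return False


    @unittest.skipIf(not HAS_CUDA, "core is not compiled with CUDA")
    class TestFP16Path(unittest.TestCase):
        def test_half_precision(self):
            if not fp16_supported():
                return        # mirrors the silent early return used above
            # loose tolerance, as in check_output_with_place(place, atol=...)
            self.assertAlmostEqual(0.1 + 0.2, 0.3, delta=2e-3)


    if __name__ == "__main__":
        unittest.main()

Keeping the decorator condition to a compile-time fact and the device capability check at runtime means the suite reports the class as skipped rather than failing on CPU-only builds.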
python/paddle/fluid/tests/unittests/test_mul_op.py
View file @ 25e070ec
...
@@ -23,12 +23,17 @@ from op_test import OpTest

 class TestMulOp(OpTest):
     def setUp(self):
         self.op_type = "mul"
+        self.dtype = np.float32
+        self.init_dtype_type()
         self.inputs = {
-            'X': np.random.random((2, 5)).astype("float32"),
-            'Y': np.random.random((5, 3)).astype("float32")
+            'X': np.random.random((2, 5)).astype(self.dtype),
+            'Y': np.random.random((5, 3)).astype(self.dtype)
         }
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
@@ -47,9 +52,11 @@ class TestMulOp(OpTest):

 class TestMulOp2(OpTest):
     def setUp(self):
         self.op_type = "mul"
+        self.dtype = np.float32
+        self.init_dtype_type()
         self.inputs = {
-            'X': np.random.random((3, 4, 4, 3)).astype("float32"),
-            'Y': np.random.random((2, 6, 1, 2, 3)).astype("float32")
+            'X': np.random.random((3, 4, 4, 3)).astype(self.dtype),
+            'Y': np.random.random((2, 6, 1, 2, 3)).astype(self.dtype)
         }
         self.attrs = {
             'x_num_col_dims': 2,
...
@@ -60,6 +67,9 @@ class TestMulOp2(OpTest):
         result = result.reshape(3, 4, 1, 2, 3)
         self.outputs = {'Out': result}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()
@@ -75,41 +85,77 @@ class TestMulOp2(OpTest):
             ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


-class TestFP16MulOp1(OpTest):
-    ...   (hand-built float16 inputs via x.view(np.float16) and a CUDA-guarded
-           test_check_output with atol=1e-1)
-
-class TestFP16MulOp2(OpTest):
-    ...   (hand-built float16 inputs, x_num_col_dims / y_num_col_dims attrs and a
-           CUDA-guarded test_check_output with atol=2e-1)
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MulOp1(TestMulOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=1e-1)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.5)
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['Y'], 'Out',
+                max_relative_error=0.5,
+                no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'], 'Out',
+                max_relative_error=0.5,
+                no_grad_set=set('Y'))
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MulOp2(TestMulOp2):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=2e-1)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.9)
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['Y'], 'Out',
+                max_relative_error=0.5,
+                no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'], 'Out',
+                max_relative_error=0.9,
+                no_grad_set=set('Y'))


 if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/test_pool2d_mkldnn_op.py
View file @ 25e070ec
...
@@ -15,10 +15,10 @@
 from __future__ import print_function

 import unittest
-from test_pool2d_op import TestPool2d_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5
+from test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5


-class TestMKLDNNCase1(TestPool2d_Op):
+class TestMKLDNNCase1(TestPool2D_Op):
     def init_kernel_type(self):
         self.use_mkldnn = True
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
浏览文件 @
25e070ec
@@ -81,7 +81,7 @@ def avg_pool2D_forward_naive(x,
     return out


-class TestPool2d_Op(OpTest):
+class TestPool2D_Op(OpTest):
     def setUp(self):
         self.op_type = "pool2d"
         self.use_cudnn = False
@@ -160,7 +160,7 @@ class TestPool2d_Op(OpTest):
         self.exclusive = True


-class TestCase1(TestPool2d_Op):
+class TestCase1(TestPool2D_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
@@ -175,7 +175,7 @@ class TestCase1(TestPool2d_Op):
         self.global_pool = False


-class TestCase2(TestPool2d_Op):
+class TestCase2(TestPool2D_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
@@ -190,7 +190,7 @@ class TestCase2(TestPool2d_Op):
         self.global_pool = False


-class TestCase3(TestPool2d_Op):
+class TestCase3(TestPool2D_Op):
     def init_pool_type(self):
         self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
@@ -208,47 +208,35 @@ class TestCase5(TestCase2):
         self.pool2D_forward_naive = max_pool2D_forward_naive


-#--------------------test pool2d--------------------
-class TestCUDNNCase1(TestPool2d_Op):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase1(TestPool2d_Op):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-
-class TestCUDNNCase2(TestCase1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase2(TestCase1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-
-class TestCUDNNCase3(TestCase2):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase3(TestCase2):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
+#--------------------test pool2d cudnn--------------------
+def create_test_cudnn_class(parent):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNOp")
+    TestCUDNNCase.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNCase
+
+
+create_test_cudnn_class(TestPool2D_Op)
+create_test_cudnn_class(TestCase1)
+create_test_cudnn_class(TestCase2)
+create_test_cudnn_class(TestCase3)
+create_test_cudnn_class(TestCase4)
+create_test_cudnn_class(TestCase5)
+
+
+#--------------------test pool2d cudnn_fp16--------------------
+def create_test_cudnn_fp16_class(parent, check_grad=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNFp16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.float16
@@ -259,76 +247,59 @@ class TestFP16CUDNNCase3(TestCase2):
             if core.is_float16_supported(place):
                 self.check_output_with_place(place, atol=1e-3)

-
-class TestCUDNNCase4(TestCase3):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase4(TestCase3):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-
-class TestCUDNNCase5(TestCase4):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase5(TestCase4):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-
-class TestCUDNNCase6(TestCase5):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNCase6(TestCase5):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)
-
-
-class TestCeilModeCase1(TestCUDNNCase1):
-    def init_ceil_mode(self):
-        self.ceil_mode = True
-
-
-class TestCeilModeCase2(TestCUDNNCase2):
-    def init_ceil_mode(self):
-        self.ceil_mode = True
-
-
-class TestCeilModeCase3(TestCase1):
-    def init_ceil_mode(self):
-        self.ceil_mode = True
-
-
-class TestCeilModeCase4(TestCase2):
-    def init_ceil_mode(self):
-        self.ceil_mode = True
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(
+                    place) and self.pool_type != "max" and check_grad:
+                self.check_grad_with_place(
+                    place, set(['X']), 'Out', max_relative_error=0.07)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16Op")
+    TestCUDNNFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNFp16Case
+
+
+create_test_cudnn_fp16_class(TestPool2D_Op)
+create_test_cudnn_fp16_class(TestCase1, check_grad=False)
+create_test_cudnn_fp16_class(TestCase2)
+create_test_cudnn_fp16_class(TestCase3)
+create_test_cudnn_fp16_class(TestCase4)
+create_test_cudnn_fp16_class(TestCase5)
+
+
+#--------------------test pool2d use ceil mode--------------------
+def create_test_cudnn_use_ceil_class(parent):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestPool2DUseCeilCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+        def init_ceil_mode(self):
+            self.ceil_mode = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNOpCeilMode")
+    TestPool2DUseCeilCase.__name__ = cls_name
+    globals()[cls_name] = TestPool2DUseCeilCase
+
+
+create_test_cudnn_use_ceil_class(TestPool2D_Op)
+create_test_cudnn_use_ceil_class(TestCase1)
+
+
+def create_test_use_ceil_class(parent):
+    class TestPool2DUseCeilCase(parent):
+        def init_ceil_mode(self):
+            self.ceil_mode = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CeilModeCast")
+    TestPool2DUseCeilCase.__name__ = cls_name
+    globals()[cls_name] = TestPool2DUseCeilCase
+
+
+create_test_use_ceil_class(TestCase1)
+create_test_use_ceil_class(TestCase2)


 class TestAvgInclude(TestCase2):
@@ -336,7 +307,10 @@ class TestAvgInclude(TestCase2):
         self.exclusive = False


-class TestCUDNNAvgInclude(TestCUDNNCase3):
+class TestCUDNNAvgInclude(TestCase2):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+
     def init_exclusive(self):
         self.exclusive = False
...
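The refactor above replaces the hand-written TestCUDNNCase1..6 / TestFP16CUDNNCase1..6 classes with small factories (create_test_cudnn_class and friends) that derive a variant from any parent case, rename it, and register it in the module namespace so unittest discovery still finds it. A stripped-down sketch of that idiom, using a stand-in HAS_ACCELERATOR flag instead of core.is_compiled_with_cuda() and invented class names, is shown here:

import unittest

HAS_ACCELERATOR = False  # stand-in for core.is_compiled_with_cuda()


class BaseCase(unittest.TestCase):
    def init_kernel_type(self):
        self.use_cudnn = False

    def test_kernel_type(self):
        self.init_kernel_type()
        self.assertIn(self.use_cudnn, (True, False))


def create_test_cudnn_class(parent):
    # Same trick as in the diff: derive a variant from `parent`, rename it,
    # and publish it in the module namespace so test discovery picks it up.
    @unittest.skipIf(not HAS_ACCELERATOR, "accelerator support not built in")
    class CUDNNVariant(parent):
        def init_kernel_type(self):
            self.use_cudnn = True

    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNOp")
    CUDNNVariant.__name__ = cls_name
    globals()[cls_name] = CUDNNVariant


create_test_cudnn_class(BaseCase)   # registers BaseCase_CUDNNOp

if __name__ == "__main__":
    unittest.main()

The payoff is that a new base configuration automatically gains cuDNN, FP16, and ceil-mode variants without another block of boilerplate subclasses.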
python/paddle/fluid/tests/unittests/test_scale_op.py
@@ -24,9 +24,16 @@ from paddle.fluid.op import Operator

 class TestScaleOp(OpTest):
     def setUp(self):
         self.op_type = "scale"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.dtype = np.float32
+        self.init_dtype_type()
+        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
         self.attrs = {'scale': -2.3}
-        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+        self.outputs = {
+            'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
+        }
+
+    def init_dtype_type(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -36,9 +43,15 @@ class TestScaleOp(OpTest):


 class TestScaleOpSelectedRows(unittest.TestCase):
+    def init_dtype_type(self):
+        pass
+
     def check_with_place(self, place, in_name, out_name):
         scope = core.Scope()
+        self.dtype = np.float32
+        self.init_dtype_type()

         # create and initialize Grad Variable
         in_height = 10
         in_rows = [0, 4, 7]
@@ -49,7 +62,7 @@ class TestScaleOpSelectedRows(unittest.TestCase):
         in_selected_rows.set_height(in_height)
         in_selected_rows.set_rows(in_rows)
         in_array = np.random.random(
-            (len(in_rows), in_row_numel)).astype("float32")
+            (len(in_rows), in_row_numel)).astype(self.dtype)

         in_tensor = in_selected_rows.get_tensor()
         in_tensor.set(in_array, place)
@@ -87,5 +100,41 @@ class TestScaleOpSelectedRows(unittest.TestCase):
             self.check_with_place(place, 'in', 'in')


+# Add FP16 test
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScaleFp16Op(TestScaleOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=0.002)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ["X"], "Out", max_relative_error=0.05)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_scale_selected_rows(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_with_place(place, 'in', 'out')
+
+    def test_scale_selected_rows_inplace(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_with_place(place, 'in', 'in')
+
+
 if __name__ == "__main__":
     unittest.main()
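One detail worth noting in the scale diff: the expected output is now computed as self.inputs['X'] * self.dtype(self.attrs['scale']), i.e. the scale factor is cast to the test dtype first. For float16 that cast changes the reference result by roughly one rounding step, which is also why the FP16 variant compares with atol=0.002 rather than exact equality. A short numpy check (independent of the test framework) makes the size of that effect visible:

import numpy as np

x = np.random.random((10, 10)).astype(np.float16)
scale = -2.3

ref_fp64 = x.astype(np.float64) * scale          # reference in full precision
ref_fp16 = x * np.float16(scale)                 # reference in the op's dtype

# The fp16 reference differs from the fp64 one by roughly the fp16 rounding
# step near |x * scale|, which is what the atol=0.002 tolerance absorbs.
print(np.abs(ref_fp64 - ref_fp16.astype(np.float64)).max())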
python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -62,10 +62,9 @@ class TestSoftmaxOp(OpTest):
             self.check_output()

     def test_check_grad(self):
-        if self.dtype == np.float16:
-            return
-        if self.use_cudnn:
+        if self.use_cudnn or self.dtype == np.float16:
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ["X"], "Out", max_relative_error=0.01)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ["X"], "Out", max_relative_error=0.01)
         else:
@@ -103,10 +102,23 @@ class TestSoftmaxFP16Op(TestSoftmaxOp):
             if core.is_float16_supported(place):
                 self.check_output_with_place(place, atol=1e-3)

+    # FIXME: If the x_shape is [10, 10], gradient failed.
+    def test_check_grad(self):
+        pass
+

 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
-class TestSoftmaxFP16Op2(TestSoftmaxFP16Op):
+class TestSoftmaxFP16Op2(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
     def get_x_shape(self):
         return [2, 3, 4, 5]
...
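The reworked test_check_grad above bounds the relative error of the softmax gradient at 0.01 when running through cuDNN. As a framework-independent sanity check, the softmax vector-Jacobian product can be verified against finite differences with plain numpy; the helper names below are illustrative only:

import numpy as np


def softmax(x):
    # Row-wise, numerically stable softmax.
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def softmax_vjp(x, grad_out):
    # For y = softmax(x): dL/dx = y * (g - sum(g * y)), row-wise.
    y = softmax(x)
    return y * (grad_out - (grad_out * y).sum(axis=-1, keepdims=True))


rng = np.random.RandomState(0)
x = rng.uniform(0.1, 1.0, (10, 10))
g = rng.uniform(-1.0, 1.0, (10, 10))

# Finite-difference estimate of dL/dx for L = sum(softmax(x) * g).
delta = 1e-5
num = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += delta
        xm[i, j] -= delta
        num[i, j] = ((softmax(xp) * g).sum() -
                     (softmax(xm) * g).sum()) / (2 * delta)

ana = softmax_vjp(x, g)
print(np.abs(ana - num).max())  # far below the 0.01 bound used in the test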
python/paddle/fluid/tests/unittests/test_sum_op.py
@@ -24,16 +24,20 @@ from paddle.fluid.op import Operator
 class TestSumOp(OpTest):
     def setUp(self):
         self.op_type = "sum"
+        self.init_kernel_type()
         self.use_mkldnn = False
         self.init_kernel_type()
-        x0 = np.random.random((3, 4)).astype('float32')
-        x1 = np.random.random((3, 4)).astype('float32')
-        x2 = np.random.random((3, 4)).astype('float32')
+        x0 = np.random.random((3, 4)).astype(self.dtype)
+        x1 = np.random.random((3, 4)).astype(self.dtype)
+        x2 = np.random.random((3, 4)).astype(self.dtype)
         self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
         y = x0 + x1 + x2
         self.outputs = {'Out': y}
         self.attrs = {'use_mkldnn': self.use_mkldnn}

+    def init_kernel_type(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output()
@@ -59,8 +63,11 @@ class TestSelectedRowsSumOp(OpTest):
         self.check_input_and_optput(core.Scope(), place, inplace, False, False,
                                     False)

+    def init_kernel_type(self):
+        self.dtype = np.float32
+
     def _get_array(self, row_num, row_numel):
-        array = np.ones((row_num, row_numel)).astype("float32")
+        array = np.ones((row_num, row_numel)).astype(self.dtype)
         for i in range(row_num):
             array[i] *= i
         return array
@@ -129,5 +136,36 @@ class TestSelectedRowsSumOp(OpTest):
             self.check_with_place(place, inplace)


+class TestFP16SumOp(TestSumOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+    # FIXME: Because of the precision fp16, max_relative_error
+    # should be 0.15 here.
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad(['x0'], 'Out', max_relative_error=0.15)
+
+
+class TestFP16SelectedRowsSumOp(TestSelectedRowsSumOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_w_is_selected_rows(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                for inplace in [True, False]:
+                    self.check_with_place(place, inplace)
+
+
 if __name__ == "__main__":
     unittest.main()
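The FP16 sum test above uses atol=2e-2 for outputs and max_relative_error=0.15 for gradients, which tracks how quickly float16 loses precision once several tensors are accumulated. A quick numpy comparison (no Paddle required) shows the order of magnitude involved:

import numpy as np

rng = np.random.RandomState(0)
xs = [rng.random_sample((3, 4)).astype(np.float16) for _ in range(3)]

# Sum in float16, as a half-precision kernel would.
y_fp16 = xs[0] + xs[1] + xs[2]

# Reference sum accumulated in float64.
y_ref = sum(x.astype(np.float64) for x in xs)

abs_err = np.abs(y_ref - y_fp16.astype(np.float64))
print("max abs err:", abs_err.max())   # on the order of 1e-3 for values in [0, 3)
print("max rel err:", (abs_err / np.abs(y_ref)).max())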