Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5271c32d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5271c32d
编写于
3月 19, 2018
作者:
K
Kexin Zhao
提交者:
GitHub
3月 19, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9223 from kexinzhao/dropout_fp16
Add float16 support to dropout operator
上级
832deee4
509c8399
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
184 addition
and
38 deletion
+184
-38
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+3
-6
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+14
-13
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+1
-1
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+133
-18
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+33
-0
未找到文件。
paddle/fluid/operators/dropout_op.cc
浏览文件 @
5271c32d
...
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
...
@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
}
};
};
template
<
typename
AttrType
>
class
DropoutOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
DropoutOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
DropoutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
DropoutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
...
@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
}
};
};
template
<
typename
AttrType
>
class
DropoutOpGrad
:
public
framework
::
OperatorWithKernel
{
class
DropoutOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
...
@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
dropout
,
ops
::
DropoutOp
,
ops
::
DropoutOpMaker
<
float
>
,
dropout_grad
,
REGISTER_OP
(
dropout
,
ops
::
DropoutOp
,
ops
::
DropoutOpMaker
,
dropout_grad
,
ops
::
DropoutOpGrad
<
float
>
);
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
dropout
,
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
paddle/fluid/operators/dropout_op.cu
浏览文件 @
5271c32d
...
@@ -18,17 +18,18 @@ limitations under the License. */
...
@@ -18,17 +18,18 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <thrust/transform.h>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
,
typename
AttrType
>
template
<
typename
T
>
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
AttrType
dropout_prob
,
const
T
*
src
,
const
float
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
T
*
mask_data
,
T
*
dst
)
{
thrust
::
minstd_rand
rng
;
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
AttrType
>
dist
(
0
,
1
);
thrust
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(;
idx
<
n
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
...
@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
// implement uniform random.
template
<
typename
Place
,
typename
T
,
typename
AttrType
>
template
<
typename
Place
,
typename
T
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
class
GPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
AttrType
dropout_prob
=
context
.
Attr
<
AttrType
>
(
"dropout_prob"
);
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
...
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
...
@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int
threads
=
512
;
int
threads
=
512
;
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
,
AttrType
><<<
grid
,
threads
,
0
,
RandomGenerator
<
context
.
cuda_device_context
().
stream
()
>>>
(
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
}
else
{
}
else
{
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
}
};
};
...
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
...
@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
dropout
,
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
float
>
);
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/dropout_op.h
浏览文件 @
5271c32d
...
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
...
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename
IndexType
=
Eigen
::
DenseIndex
>
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
DeviceContext
,
typename
T
,
typename
AttrType
>
template
<
typename
DeviceContext
,
typename
T
>
class
CPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
...
...
paddle/fluid/platform/float16.h
浏览文件 @
5271c32d
...
@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
...
@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16
#endif // PADDLE_CUDA_FP16
// Arithmetic operators on ARMv8.2-A CPU
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_WITH_NATIVE_FP16)
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hadd
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
+
float
(
b
));
#endif
}
HOSTDEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hsub
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
-
float
(
b
));
#endif
}
HOSTDEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hmul
(
half
(
a
),
half
(
b
)));
#else
return
float16
(
float
(
a
)
*
float
(
b
));
#endif
}
HOSTDEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float
num
=
__half2float
(
half
(
a
));
float
denom
=
__half2float
(
half
(
b
));
return
float16
(
num
/
denom
);
#else
return
float16
(
float
(
a
)
/
float
(
b
));
#endif
}
HOSTDEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
float16
(
__hneg
(
half
(
a
)));
#else
float16
res
;
res
.
x
=
a
.
x
^
0x8000
;
return
res
;
#endif
}
HOSTDEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
+
b
;
return
a
;
}
HOSTDEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
-
b
;
return
a
;
}
HOSTDEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
*
b
;
return
a
;
}
HOSTDEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
a
/
b
;
return
a
;
}
HOSTDEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__heq
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
==
float
(
b
);
#endif
}
HOSTDEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hne
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
!=
float
(
b
);
#endif
}
HOSTDEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hlt
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
<
float
(
b
);
#endif
}
HOSTDEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hle
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
<=
float
(
b
);
#endif
}
HOSTDEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hgt
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
>
float
(
b
);
#endif
}
HOSTDEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return
__hge
(
half
(
a
),
half
(
b
));
#else
return
float
(
a
)
>=
float
(
b
);
#endif
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
float16
res
;
float16
res
;
asm
volatile
(
asm
volatile
(
...
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
...
@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return
(
res
&
0xffff
)
!=
0
;
return
(
res
&
0xffff
)
!=
0
;
}
}
// Arithmetic operators, software emulated on other CPU
// Arithmetic operators
for float16
, software emulated on other CPU
#else
#else
HOST
DEVICE
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
+
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
+
float
(
b
));
return
float16
(
float
(
a
)
+
float
(
b
));
}
}
HOST
DEVICE
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
-
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
-
float
(
b
));
return
float16
(
float
(
a
)
-
float
(
b
));
}
}
HOST
DEVICE
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
*
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
*
float
(
b
));
return
float16
(
float
(
a
)
*
float
(
b
));
}
}
HOST
DEVICE
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
operator
/
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float16
(
float
(
a
)
/
float
(
b
));
return
float16
(
float
(
a
)
/
float
(
b
));
}
}
HOST
DEVICE
inline
float16
operator
-
(
const
float16
&
a
)
{
HOST
inline
float16
operator
-
(
const
float16
&
a
)
{
float16
res
;
float16
res
;
res
.
x
=
a
.
x
^
0x8000
;
res
.
x
=
a
.
x
^
0x8000
;
return
res
;
return
res
;
}
}
HOST
DEVICE
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
+=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
+
float
(
b
));
a
=
float16
(
float
(
a
)
+
float
(
b
));
return
a
;
return
a
;
}
}
HOST
DEVICE
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
-=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
-
float
(
b
));
a
=
float16
(
float
(
a
)
-
float
(
b
));
return
a
;
return
a
;
}
}
HOST
DEVICE
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
*=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
*
float
(
b
));
a
=
float16
(
float
(
a
)
*
float
(
b
));
return
a
;
return
a
;
}
}
HOST
DEVICE
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
float16
&
operator
/=
(
float16
&
a
,
const
float16
&
b
)
{
a
=
float16
(
float
(
a
)
/
float
(
b
));
a
=
float16
(
float
(
a
)
/
float
(
b
));
return
a
;
return
a
;
}
}
HOST
DEVICE
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
==
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
==
float
(
b
);
return
float
(
a
)
==
float
(
b
);
}
}
HOST
DEVICE
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
!=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
!=
float
(
b
);
return
float
(
a
)
!=
float
(
b
);
}
}
HOST
DEVICE
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
<
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
<
float
(
b
);
return
float
(
a
)
<
float
(
b
);
}
}
HOST
DEVICE
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
<=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
<=
float
(
b
);
return
float
(
a
)
<=
float
(
b
);
}
}
HOST
DEVICE
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
>
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
>
float
(
b
);
return
float
(
a
)
>
float
(
b
);
}
}
HOST
DEVICE
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
HOST
inline
bool
operator
>=
(
const
float16
&
a
,
const
float16
&
b
)
{
return
float
(
a
)
>=
float
(
b
);
return
float
(
a
)
>=
float
(
b
);
}
}
#endif
#endif
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
5271c32d
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
from
op_test
import
OpTest
...
@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
...
@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
self
.
check_output
()
self
.
check_output
()
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
init_test_case
()
x
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
"float16"
)
out
=
x
*
(
1.0
-
self
.
prob
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
self
.
attrs
=
{
'dropout_prob'
:
self
.
prob
,
'fix_seed'
:
self
.
fix_seed
,
'is_test'
:
True
}
self
.
outputs
=
{
'Out'
:
out
}
def
init_test_case
(
self
):
self
.
input_size
=
[
32
,
64
]
self
.
prob
=
0.35
self
.
fix_seed
=
True
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
"dropout"
):
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
),
atol
=
1e-3
)
class
TestFP16DropoutOp2
(
TestFP16DropoutOp
):
def
init_test_case
(
self
):
self
.
input_size
=
[
32
,
64
,
3
]
self
.
prob
=
0.75
self
.
fix_seed
=
False
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录