Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
379d933a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
379d933a
编写于
10月 26, 2018
作者:
H
Hongyu Liu
提交者:
GitHub
10月 26, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14036 from phlrain/add_dropout_att_new
Add dropout att new 1.1 merge
上级
18be7256
a4ad286e
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
174 addition
and
28 deletion
+174
-28
.gitignore
.gitignore
+1
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+28
-2
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+22
-7
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+16
-3
paddle/fluid/operators/softmax_cudnn_op.cu.cc
paddle/fluid/operators/softmax_cudnn_op.cu.cc
+3
-1
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+8
-5
paddle/fluid/operators/transpose_op.cu.cc
paddle/fluid/operators/transpose_op.cu.cc
+8
-5
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+1
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+23
-2
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+63
-0
未找到文件。
.gitignore
浏览文件 @
379d933a
...
...
@@ -28,3 +28,4 @@ third_party/
build_*
# clion workspace.
cmake-build-*
model_test
paddle/fluid/API.spec
浏览文件 @
379d933a
...
...
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
], varargs=None, keywords=None, defaults=(False, None, None
))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
, 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'
))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...
...
paddle/fluid/operators/dropout_op.cc
浏览文件 @
379d933a
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
...
...
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"dropout_implementation"
,
"[
\"
downgrade_in_infer
\"
|
\"
upscale_in_train
\"
]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient"
)
.
SetDefault
(
"downgrade_in_infer"
)
.
AddCustomChecker
([](
const
std
::
string
&
type
)
{
PADDLE_ENFORCE
(
type
==
"downgrade_in_infer"
||
type
==
"upscale_in_train"
,
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"
);
});
AddComment
(
R"DOC(
Dropout Operator.
...
...
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.cu
浏览文件 @
379d933a
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -26,7 +27,8 @@ namespace operators {
template
<
typename
T
>
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
float
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
T
*
mask_data
,
T
*
dst
,
bool
is_upscale_in_train
)
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
...
...
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if
(
dist
(
rng
)
<
dropout_prob
)
{
mask
=
static_cast
<
T
>
(
0
);
}
else
{
mask
=
static_cast
<
T
>
(
1
);
if
(
is_upscale_in_train
)
{
mask
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
}
else
{
mask
=
static_cast
<
T
>
(
1
);
}
}
dest
=
s
*
mask
;
mask_data
[
idx
]
=
mask
;
...
...
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
...
...
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
(
dropout_implementation
==
"upscale_in_train"
));
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.h
浏览文件 @
379d933a
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
if
(
dropout_implementation
==
"upscale_in_train"
)
{
mask_data
[
i
]
=
1.0
f
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
...
...
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
paddle/fluid/operators/softmax_cudnn_op.cu.cc
浏览文件 @
379d933a
...
...
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
softmax
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxCUDNNKernel
<
float
>
,
ops
::
SoftmaxCUDNNKernel
<
double
>
,
ops
::
SoftmaxCUDNNKernel
<
plat
::
float16
>
);
REGISTER_OP_KERNEL
(
softmax_grad
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxGradCUDNNKernel
<
float
>
);
ops
::
SoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SoftmaxGradCUDNNKernel
<
double
>
);
paddle/fluid/operators/transpose_op.cc
浏览文件 @
379d933a
...
...
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
ops
::
Transpose2OpMaker
,
ops
::
Transpose2GradMaker
);
REGISTER_OPERATOR
(
transpose2_grad
,
ops
::
Transpose2OpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/transpose_op.cu.cc
浏览文件 @
379d933a
...
...
@@ -16,15 +16,18 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
python/paddle/fluid/clip.py
浏览文件 @
379d933a
...
...
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
)
square
=
grad
*
grad
local_norm_var
=
layers
.
cast
(
layers
.
reduce_sum
(
input
=
square
),
'float64'
)
local_norm_var
=
layers
.
reduce_sum
(
input
=
square
)
context
[
self
.
group_name
].
append
(
local_norm_var
)
self
.
context
=
context
...
...
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
group_norm_var
=
layers
.
sqrt
(
x
=
group_norm_var
)
group_norm_var
=
layers
.
cast
(
group_norm_var
,
'float32'
)
clip_var
=
self
.
context
[
self
.
group_name
+
"_clip"
]
group_scale_var
=
layers
.
elementwise_div
(
x
=
clip_var
,
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
379d933a
...
...
@@ -980,7 +980,12 @@ def cos_sim(X, Y):
return
out
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
):
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
,
dropout_implementation
=
"downgrade_in_infer"
):
"""
Computes dropout.
...
...
@@ -1000,6 +1005,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * dropout_prob
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns:
Variable: A tensor variable is the shape with `x`.
...
...
@@ -1029,7 +1049,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob'
:
dropout_prob
,
'is_test'
:
is_test
,
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
'seed'
:
seed
if
seed
is
not
None
else
0
,
'dropout_implementation'
:
dropout_implementation
,
})
return
out
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
379d933a
...
...
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self
.
check_output
()
class
TestDropoutOp6
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
),
'Mask'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
)
}
class
TestDropoutOp7
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
)).
astype
(
'float32'
)
}
class
TestDropoutOp8
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'fix_seed'
:
True
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestDropoutOp9
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录