Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
379d933a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
379d933a
编写于
10月 26, 2018
作者:
H
Hongyu Liu
提交者:
GitHub
10月 26, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14036 from phlrain/add_dropout_att_new
Add dropout att new 1.1 merge
上级
18be7256
a4ad286e
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
174 addition
and
28 deletion
+174
-28
.gitignore
.gitignore
+1
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-1
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+28
-2
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+22
-7
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+16
-3
paddle/fluid/operators/softmax_cudnn_op.cu.cc
paddle/fluid/operators/softmax_cudnn_op.cu.cc
+3
-1
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+8
-5
paddle/fluid/operators/transpose_op.cu.cc
paddle/fluid/operators/transpose_op.cu.cc
+8
-5
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+1
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+23
-2
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+63
-0
未找到文件。
.gitignore
浏览文件 @
379d933a
...
...
@@ -28,3 +28,4 @@ third_party/
build_*
# clion workspace.
cmake-build-*
model_test
paddle/fluid/API.spec
浏览文件 @
379d933a
...
...
@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
], varargs=None, keywords=None, defaults=(False, None, None
))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'
, 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'
))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
...
...
paddle/fluid/operators/dropout_op.cc
浏览文件 @
379d933a
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace
paddle
{
namespace
operators
{
...
...
@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"seed"
,
"Dropout random seed."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"dropout_implementation"
,
"[
\"
downgrade_in_infer
\"
|
\"
upscale_in_train
\"
]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient"
)
.
SetDefault
(
"downgrade_in_infer"
)
.
AddCustomChecker
([](
const
std
::
string
&
type
)
{
PADDLE_ENFORCE
(
type
==
"downgrade_in_infer"
||
type
==
"upscale_in_train"
,
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"
);
});
AddComment
(
R"DOC(
Dropout Operator.
...
...
@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.cu
浏览文件 @
379d933a
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -26,7 +27,8 @@ namespace operators {
template
<
typename
T
>
__global__
void
RandomGenerator
(
const
size_t
n
,
const
int
seed
,
const
float
dropout_prob
,
const
T
*
src
,
T
*
mask_data
,
T
*
dst
)
{
T
*
mask_data
,
T
*
dst
,
bool
is_upscale_in_train
)
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed
);
thrust
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
...
...
@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if
(
dist
(
rng
)
<
dropout_prob
)
{
mask
=
static_cast
<
T
>
(
0
);
}
else
{
mask
=
static_cast
<
T
>
(
1
);
if
(
is_upscale_in_train
)
{
mask
=
static_cast
<
T
>
(
1.0
f
/
(
1.0
f
-
dropout_prob
));
}
else
{
mask
=
static_cast
<
T
>
(
1
);
}
}
dest
=
s
*
mask
;
mask_data
[
idx
]
=
mask
;
...
...
@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
...
...
@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int
grid
=
(
x
->
numel
()
+
threads
-
1
)
/
threads
;
RandomGenerator
<
T
><<<
grid
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
);
size
,
seed
,
dropout_prob
,
x_data
,
mask_data
,
y_data
,
(
dropout_implementation
==
"upscale_in_train"
));
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.h
浏览文件 @
379d933a
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine
.
seed
(
seed
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
size_t
size
=
framework
::
product
(
mask
->
dims
());
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
if
(
dropout_implementation
==
"upscale_in_train"
)
{
mask_data
[
i
]
=
1.0
f
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
mask_data
[
i
]
=
1
;
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
...
...
@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Y
.
device
(
place
)
=
X
*
(
1.0
f
-
dropout_prob
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
Y
.
device
(
place
)
=
X
;
}
else
{
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
...
...
paddle/fluid/operators/softmax_cudnn_op.cu.cc
浏览文件 @
379d933a
...
...
@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_KERNEL
(
softmax
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxCUDNNKernel
<
float
>
,
ops
::
SoftmaxCUDNNKernel
<
double
>
,
ops
::
SoftmaxCUDNNKernel
<
plat
::
float16
>
);
REGISTER_OP_KERNEL
(
softmax_grad
,
CUDNN
,
plat
::
CUDAPlace
,
ops
::
SoftmaxGradCUDNNKernel
<
float
>
);
ops
::
SoftmaxGradCUDNNKernel
<
float
>
,
ops
::
SoftmaxGradCUDNNKernel
<
double
>
);
paddle/fluid/operators/transpose_op.cc
浏览文件 @
379d933a
...
...
@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR
(
transpose_grad
,
ops
::
TransposeOpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OPERATOR
(
transpose2
,
ops
::
Transpose2Op
,
ops
::
Transpose2OpMaker
,
ops
::
Transpose2GradMaker
);
REGISTER_OPERATOR
(
transpose2_grad
,
ops
::
Transpose2OpGrad
);
REGISTER_OP_CPU_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/transpose_op.cu.cc
浏览文件 @
379d933a
...
...
@@ -16,15 +16,18 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
transpose
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
transpose2_grad
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
);
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TransposeGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
python/paddle/fluid/clip.py
浏览文件 @
379d933a
...
...
@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
)
square
=
grad
*
grad
local_norm_var
=
layers
.
cast
(
layers
.
reduce_sum
(
input
=
square
),
'float64'
)
local_norm_var
=
layers
.
reduce_sum
(
input
=
square
)
context
[
self
.
group_name
].
append
(
local_norm_var
)
self
.
context
=
context
...
...
@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if
group_scale_name
not
in
self
.
context
:
group_norm_var
=
layers
.
sums
(
input
=
self
.
context
[
self
.
group_name
])
group_norm_var
=
layers
.
sqrt
(
x
=
group_norm_var
)
group_norm_var
=
layers
.
cast
(
group_norm_var
,
'float32'
)
clip_var
=
self
.
context
[
self
.
group_name
+
"_clip"
]
group_scale_var
=
layers
.
elementwise_div
(
x
=
clip_var
,
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
379d933a
...
...
@@ -980,7 +980,12 @@ def cos_sim(X, Y):
return
out
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
):
def
dropout
(
x
,
dropout_prob
,
is_test
=
False
,
seed
=
None
,
name
=
None
,
dropout_implementation
=
"downgrade_in_infer"
):
"""
Computes dropout.
...
...
@@ -1000,6 +1005,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * dropout_prob
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns:
Variable: A tensor variable is the shape with `x`.
...
...
@@ -1029,7 +1049,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob'
:
dropout_prob
,
'is_test'
:
is_test
,
'fix_seed'
:
seed
is
not
None
,
'seed'
:
seed
if
seed
is
not
None
else
0
'seed'
:
seed
if
seed
is
not
None
else
0
,
'dropout_implementation'
:
dropout_implementation
,
})
return
out
...
...
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
379d933a
...
...
@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self
.
check_output
()
class
TestDropoutOp6
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
1.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
),
'Mask'
:
np
.
zeros
((
32
,
64
)).
astype
(
'float32'
)
}
class
TestDropoutOp7
(
TestDropoutOp
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
2
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.0
,
'fix_seed'
:
True
,
'is_test'
:
False
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
],
'Mask'
:
np
.
ones
((
32
,
64
,
2
)).
astype
(
'float32'
)
}
class
TestDropoutOp8
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.35
,
'fix_seed'
:
True
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestDropoutOp9
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
32
,
64
,
3
)).
astype
(
"float32"
)}
self
.
attrs
=
{
'dropout_prob'
:
0.75
,
'is_test'
:
True
,
'dropout_implementation'
:
'upscale_in_train'
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFP16DropoutOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"dropout"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录