Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
99fc1b08
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
99fc1b08
编写于
3月 10, 2022
作者:
H
hong
提交者:
GitHub
3月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move dropout to phi (#40148)
* move dropout to phi; test=develop * fix xpu, npu compile error; test=develop
上级
843f6da0
变更
41
隐藏空白更改
内联
并排
Showing
41 changed file
with
481 addition
and
303 deletion
+481
-303
paddle/fluid/inference/tensorrt/convert/dropout_op.cc
paddle/fluid/inference/tensorrt/convert/dropout_op.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
+1
-1
paddle/fluid/operators/assign_op_npu_test.cc
paddle/fluid/operators/assign_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_allgather_op_npu_test.cc
paddle/fluid/operators/collective/c_allgather_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc
...fluid/operators/collective/c_allreduce_max_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc
...fluid/operators/collective/c_allreduce_sum_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc
paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc
...le/fluid/operators/collective/c_reduce_sum_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc
...fluid/operators/collective/c_reducescatter_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc
...id/operators/collective/c_sync_comm_stream_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/checknumeric_npu_test.cc
paddle/fluid/operators/collective/checknumeric_npu_test.cc
+0
-1
paddle/fluid/operators/collective/recv_v2_op_npu_test.cc
paddle/fluid/operators/collective/recv_v2_op_npu_test.cc
+0
-1
paddle/fluid/operators/collective/send_v2_op_npu_test.cc
paddle/fluid/operators/collective/send_v2_op_npu_test.cc
+0
-1
paddle/fluid/operators/dropout_impl.cu.h
paddle/fluid/operators/dropout_impl.cu.h
+15
-12
paddle/fluid/operators/dropout_impl_util.h
paddle/fluid/operators/dropout_impl_util.h
+1
-1
paddle/fluid/operators/dropout_op.cc
paddle/fluid/operators/dropout_op.cc
+1
-12
paddle/fluid/operators/dropout_op.cu
paddle/fluid/operators/dropout_op.cu
+0
-94
paddle/fluid/operators/dropout_op.h
paddle/fluid/operators/dropout_op.h
+0
-151
paddle/fluid/operators/dropout_op_npu.cc
paddle/fluid/operators/dropout_op_npu.cc
+1
-1
paddle/fluid/operators/dropout_op_test.cc
paddle/fluid/operators/dropout_op_test.cc
+1
-2
paddle/fluid/operators/dropout_op_xpu.cc
paddle/fluid/operators/dropout_op_xpu.cc
+3
-1
paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
...le/fluid/operators/elementwise/elementwise_op_npu_test.cc
+0
-1
paddle/fluid/operators/expand_op_npu_test.cc
paddle/fluid/operators/expand_op_npu_test.cc
+0
-1
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+6
-5
paddle/fluid/operators/fused/fused_dropout_test.h
paddle/fluid/operators/fused/fused_dropout_test.h
+1
-1
paddle/fluid/operators/gelu_op_npu_test.cc
paddle/fluid/operators/gelu_op_npu_test.cc
+0
-1
paddle/fluid/operators/increment_op_npu_test.cc
paddle/fluid/operators/increment_op_npu_test.cc
+0
-1
paddle/fluid/operators/range_op_npu_test.cc
paddle/fluid/operators/range_op_npu_test.cc
+0
-1
paddle/fluid/operators/rnn_op.h
paddle/fluid/operators/rnn_op.h
+9
-1
paddle/fluid/operators/softmax_op_npu_test.cc
paddle/fluid/operators/softmax_op_npu_test.cc
+0
-1
paddle/fluid/operators/squeeze_op_npu_test.cc
paddle/fluid/operators/squeeze_op_npu_test.cc
+0
-1
paddle/fluid/operators/transpose_op_npu_test.cc
paddle/fluid/operators/transpose_op_npu_test.cc
+0
-1
paddle/fluid/operators/unsqueeze_op_npu_test.cc
paddle/fluid/operators/unsqueeze_op_npu_test.cc
+0
-1
paddle/phi/kernels/cpu/dropout_grad_kernel.cc
paddle/phi/kernels/cpu/dropout_grad_kernel.cc
+67
-0
paddle/phi/kernels/cpu/dropout_kernel.cc
paddle/phi/kernels/cpu/dropout_kernel.cc
+104
-0
paddle/phi/kernels/dropout_grad_kernel.h
paddle/phi/kernels/dropout_grad_kernel.h
+31
-0
paddle/phi/kernels/dropout_kernel.h
paddle/phi/kernels/dropout_kernel.h
+34
-0
paddle/phi/kernels/gpu/dropout_grad_kernel.cu
paddle/phi/kernels/gpu/dropout_grad_kernel.cu
+46
-0
paddle/phi/kernels/gpu/dropout_kernel.cu
paddle/phi/kernels/gpu/dropout_kernel.cu
+61
-0
paddle/phi/ops/compat/dropout_sig.cc
paddle/phi/ops/compat/dropout_sig.cc
+38
-0
python/paddle/fluid/tests/unittests/test_dropout_op.py
python/paddle/fluid/tests/unittests/test_dropout_op.py
+60
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/dropout_op.cc
浏览文件 @
99fc1b08
...
...
@@ -89,5 +89,5 @@ class DropoutOpConverter : public OpConverter {
}
// namespace inference
}
// namespace paddle
USE_OP
(
dropout
);
USE_OP
_ITSELF
(
dropout
);
REGISTER_TRT_OP_CONVERTER
(
dropout
,
DropoutOpConverter
);
paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
浏览文件 @
99fc1b08
...
...
@@ -57,4 +57,4 @@ TEST(DropoutOpConverter, main) {
}
// namespace inference
}
// namespace paddle
USE_OP
(
dropout
);
USE_OP
_ITSELF
(
dropout
);
paddle/fluid/operators/assign_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_allgather_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/checknumeric_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -27,7 +27,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/recv_v2_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/collective/send_v2_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/dropout_impl.cu.h
浏览文件 @
99fc1b08
...
...
@@ -32,10 +32,9 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/
device/gpu/gpu_launch_config
.h"
#include "paddle/phi/
kernels/funcs/aligned_vector
.h"
#include "paddle/fluid/platform/
aligned_vector
.h"
#include "paddle/phi/
backends/gpu/gpu_launch_config
.h"
#include "paddle/phi/kernels/funcs/functors.h"
namespace
paddle
{
...
...
@@ -177,12 +176,13 @@ __global__ void DropoutGradCUDAKernel(
}
template
<
typename
T
>
void
DropoutFwGPUKernelDriver
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
bool
is_test
,
void
DropoutFwGPUKernelDriver
(
const
phi
::
GPUContext
&
dev_ctx
,
bool
is_test
,
const
std
::
string
dropout_implementation
,
float
dropout_prob
,
bool
upscale_in_train
,
bool
is_fix_seed
,
int
seed_val
,
const
Tensor
&
x
,
const
Tensor
*
seed
,
Tensor
*
mask
,
Tensor
*
y
)
{
bool
is_fix_seed
,
int
seed_val
,
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
*
seed
,
framework
::
Tensor
*
mask
,
framework
::
Tensor
*
y
)
{
auto
&
place
=
*
dev_ctx
.
eigen_device
();
int64_t
x_numel
=
x
.
numel
();
auto
stream
=
dev_ctx
.
stream
();
...
...
@@ -220,7 +220,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
// VectorizedRandomGenerator use curand_uniform4, so we only support
// vec_size is 4;
int
vec_size
=
(
phi
::
GetVectorizedSize
<
T
>
(
x_data
)
==
4
)
?
4
:
1
;
auto
gpu_config
=
GetGpuLaunchConfig1D
(
dev_ctx
,
x_numel
,
vec_size
);
auto
gpu_config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
x_numel
,
vec_size
);
auto
offset
=
((
x_numel
-
1
)
/
(
gpu_config
.
GetThreadNum
()
*
vec_size
)
+
1
)
*
vec_size
;
...
...
@@ -278,11 +279,13 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
}
template
<
typename
T
>
void
DropoutGradGPUKernelDriver
(
const
p
latform
::
CUDADevice
Context
&
dev_ctx
,
void
DropoutGradGPUKernelDriver
(
const
p
hi
::
GPU
Context
&
dev_ctx
,
const
std
::
string
dropout_implementation
,
float
dropout_prob
,
const
Tensor
&
grad_y
,
const
Tensor
&
mask
,
int64_t
size
,
Tensor
*
grad_x
,
bool
is_test
=
false
)
{
float
dropout_prob
,
const
framework
::
Tensor
&
grad_y
,
const
framework
::
Tensor
&
mask
,
int64_t
size
,
framework
::
Tensor
*
grad_x
,
bool
is_test
=
false
)
{
using
MT
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
auto
stream
=
dev_ctx
.
stream
();
MT
factor
;
...
...
paddle/fluid/operators/dropout_impl_util.h
浏览文件 @
99fc1b08
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
inline
void
GetSeedDataAndIncrement
(
const
p
latform
::
CUDADevice
Context
&
dev_ctx
,
inline
void
GetSeedDataAndIncrement
(
const
p
hi
::
GPU
Context
&
dev_ctx
,
const
framework
::
Tensor
*
seed
,
const
bool
is_fix_seed
,
const
int
seed_val
,
const
int
offset
,
uint64_t
*
seed_data
,
...
...
paddle/fluid/operators/dropout_op.cc
浏览文件 @
99fc1b08
...
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -177,14 +177,3 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
ops
::
DropoutGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
DropoutGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
dropout_grad
,
ops
::
DropoutOpGrad
);
REGISTER_OP_CPU_KERNEL
(
dropout
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
CPUDropoutKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
dropout_grad
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
DropoutGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
bfloat16
>
);
paddle/fluid/operators/dropout_op.cu
已删除
100644 → 0
浏览文件 @
843f6da0
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template
<
typename
Place
,
typename
T
>
class
GPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
seed
=
context
.
HasInput
(
"Seed"
)
?
context
.
Input
<
Tensor
>
(
"Seed"
)
:
nullptr
;
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
&
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
bool
upscale_in_train
=
(
dropout_implementation
==
"upscale_in_train"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
&
dev_ctx
=
context
.
cuda_device_context
();
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
mask
->
mutable_data
<
uint8_t
>
(
context
.
GetPlace
());
bool
is_fix_seed
=
context
.
Attr
<
bool
>
(
"fix_seed"
);
int
seed_val
=
context
.
Attr
<
int
>
(
"seed"
);
DropoutFwGPUKernelDriver
<
T
>
(
dev_ctx
,
is_test
,
dropout_implementation
,
dropout_prob
,
upscale_in_train
,
is_fix_seed
,
seed_val
,
*
x
,
seed
,
mask
,
y
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
GPUDropoutGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
grad_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
size
=
grad_x
->
numel
();
auto
&
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
DropoutGradGPUKernelDriver
<
T
>
(
dev_ctx
,
dropout_implementation
,
dropout_prob
,
*
grad_y
,
*
mask
,
size
,
grad_x
,
is_test
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
dropout
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
GPUDropoutKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
dropout_grad
,
ops
::
GPUDropoutGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
GPUDropoutGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
GPUDropoutGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
bfloat16
>
,
ops
::
GPUDropoutGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/dropout_op.h
已删除
100644 → 0
浏览文件 @
843f6da0
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstring>
#include <random>
#include <string>
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
DeviceContext
,
typename
T
>
class
CPUDropoutKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
seed
=
context
.
HasInput
(
"Seed"
)
?
context
.
Input
<
Tensor
>
(
"Seed"
)
:
nullptr
;
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
const
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
auto
&
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
bool
upscale_in_train
=
(
dropout_implementation
==
"upscale_in_train"
);
if
(
!
context
.
Attr
<
bool
>
(
"is_test"
))
{
auto
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
auto
*
mask_data
=
mask
->
mutable_data
<
uint8_t
>
(
context
.
GetPlace
());
size_t
size
=
phi
::
product
(
mask
->
dims
());
// Special case when dropout_prob is 1.0
if
(
dropout_prob
==
1.0
f
)
{
std
::
memset
(
y_data
,
0
,
size
*
sizeof
(
*
y_data
));
// NOLINT
std
::
memset
(
mask_data
,
0
,
size
*
sizeof
(
*
mask_data
));
// NOLINT
return
;
}
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
int
seed_data
=
0
;
if
(
seed
)
{
seed_data
=
*
(
seed
->
data
<
int
>
());
}
else
{
seed_data
=
context
.
Attr
<
bool
>
(
"fix_seed"
)
?
context
.
Attr
<
int
>
(
"seed"
)
:
0
;
}
auto
engine
=
framework
::
GetCPURandomEngine
(
seed_data
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
*
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
if
(
upscale_in_train
)
{
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
if
(
upscale_in_train
)
{
const
auto
*
X_data
=
x
->
data
<
T
>
();
auto
*
Y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for
(
int
i
=
0
;
i
<
x
->
numel
();
i
++
)
{
Y_data
[
i
]
=
X_data
[
i
];
}
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
*
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DropoutGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
grad_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dX
=
EigenVector
<
T
>::
Flatten
(
*
grad_x
);
auto
dY
=
EigenVector
<
T
>::
Flatten
(
*
grad_y
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
dropout_implementation
=
context
.
Attr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
context
.
Attr
<
bool
>
(
"is_test"
)
==
true
)
{
if
(
dropout_implementation
==
"upscale_in_train"
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
1
)
*
dY
;
}
else
{
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
dX
.
device
(
place
)
=
dY
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
else
{
auto
M
=
EigenVector
<
uint8_t
>::
Flatten
(
*
mask
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
float
dropout_prob
=
context
.
Attr
<
float
>
(
"dropout_prob"
);
if
(
dropout_prob
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
();
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/dropout_op_npu.cc
浏览文件 @
99fc1b08
...
...
@@ -15,8 +15,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/core/ddim.h"
...
...
paddle/fluid/operators/dropout_op_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,14 +24,13 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
f
=
paddle
::
framework
;
namespace
p
=
paddle
::
platform
;
USE_OP
(
dropout
);
USE_OP
_ITSELF
(
dropout
);
void
Compare
(
f
::
Scope
*
scope
,
const
p
::
DeviceContext
&
ctx
)
{
// init
...
...
paddle/fluid/operators/dropout_op_xpu.cc
浏览文件 @
99fc1b08
...
...
@@ -8,15 +8,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace
paddle
{
namespace
operators
{
#ifdef PADDLE_WITH_XPU
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
DropoutXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUTyp
=
typename
XPUTypeTrait
<
T
>::
Type
;
...
...
paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/expand_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/fused/fmha_ref.h
浏览文件 @
99fc1b08
...
...
@@ -140,9 +140,9 @@ class FMHARef {
if
(
dropout_param_
.
dropout_prob_
)
{
DropoutFwGPUKernelDriver
<
T
>
(
dev_ctx_
,
dropout_param_
.
is_test_
,
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
static_cast
<
const
phi
::
GPUContext
&>
(
dev_ctx_
)
,
dropout_param_
.
is_test_
,
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
dropout_param_
.
dropout_prob_
,
dropout_param_
.
is_upscale_in_train_
,
dropout_param_
.
is_fix_seed_
,
dropout_param_
.
seed_val_
,
static_cast
<
const
Tensor
&>
(
*
softmax_out_tensor
),
dropout_param_
.
seed_
,
...
...
@@ -242,8 +242,9 @@ class FMHARef {
// dropout bw
if
(
dropout_param_
.
dropout_prob_
)
{
DropoutGradGPUKernelDriver
<
T
>
(
dev_ctx_
,
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
static_cast
<
const
phi
::
GPUContext
&>
(
dev_ctx_
),
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
dropout_param_
.
dropout_prob_
,
static_cast
<
const
Tensor
&>
(
*
dropout_out_grad_tensor
),
dropout_mask_out_tensor
,
softmax_out_grad_tensor
->
numel
(),
...
...
paddle/fluid/operators/fused/fused_dropout_test.h
浏览文件 @
99fc1b08
...
...
@@ -31,7 +31,7 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
memory
=
paddle
::
memory
;
USE_OP
(
dropout
);
USE_OP
_ITSELF
(
dropout
);
USE_OP
(
layer_norm
);
template
<
typename
T
>
...
...
paddle/fluid/operators/gelu_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/increment_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/range_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/rnn_op.h
浏览文件 @
99fc1b08
...
...
@@ -16,9 +16,9 @@ limitations under the License. */
#include <type_traits>
#include <vector>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/fc.h"
#include "paddle/fluid/operators/unique_op.h"
...
...
@@ -36,6 +36,14 @@ using LoDTensor = framework::LoDTensor;
using
Tensor
=
framework
::
Tensor
;
using
TensorList
=
std
::
vector
<
framework
::
Tensor
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \
const std::string& mode = ctx.Attr<std::string>("mode"); \
...
...
paddle/fluid/operators/softmax_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/squeeze_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/transpose_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/fluid/operators/unsqueeze_op_npu_test.cc
浏览文件 @
99fc1b08
...
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h"
...
...
paddle/phi/kernels/cpu/dropout_grad_kernel.cc
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dropout_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutGradRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
mask
,
const
DenseTensor
&
out_grad
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
DenseTensor
*
x_grad
)
{
auto
*
grad_x
=
x_grad
;
auto
*
grad_y
=
&
out_grad
;
grad_x
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
dX
=
EigenVector
<
T
>::
Flatten
(
*
grad_x
);
auto
dY
=
EigenVector
<
T
>::
Flatten
(
*
grad_y
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
auto
&
dropout_implementation
=
mode
;
if
(
is_test
==
true
)
{
if
(
dropout_implementation
==
"upscale_in_train"
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
1
)
*
dY
;
}
else
{
dX
.
device
(
place
)
=
dY
*
static_cast
<
T
>
(
1.0
f
-
p
);
}
}
else
{
auto
M
=
EigenVector
<
uint8_t
>::
Flatten
(
mask
);
if
(
dropout_implementation
==
"upscale_in_train"
)
{
if
(
p
==
1.0
f
)
{
dX
.
device
(
place
)
=
static_cast
<
T
>
(
0
)
*
dY
;
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
()
/
static_cast
<
T
>
(
1.0
f
-
p
);
}
}
else
{
dX
.
device
(
place
)
=
dY
*
M
.
cast
<
T
>
();
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
dropout_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
DropoutGradRawKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/cpu/dropout_kernel.cc
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dropout_kernel.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
paddle
::
optional
<
const
DenseTensor
&>
seed_tensor
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
int
seed
,
bool
fix_seed
,
DenseTensor
*
out
,
DenseTensor
*
mask
)
{
auto
*
y
=
out
;
const
auto
*
x_data
=
x
.
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
float
dropout_prob
=
p
;
auto
&
dropout_implementation
=
mode
;
bool
upscale_in_train
=
(
dropout_implementation
==
"upscale_in_train"
);
if
(
!
is_test
)
{
auto
*
mask_data
=
mask
->
mutable_data
<
uint8_t
>
(
dev_ctx
.
GetPlace
());
size_t
size
=
phi
::
product
(
mask
->
dims
());
// Special case when dropout_prob is 1.0
if
(
dropout_prob
==
1.0
f
)
{
std
::
memset
(
y_data
,
0
,
size
*
sizeof
(
*
y_data
));
// NOLINT
std
::
memset
(
mask_data
,
0
,
size
*
sizeof
(
*
mask_data
));
// NOLINT
return
;
}
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
int
seed_data
=
0
;
if
(
seed_tensor
.
get_ptr
()
!=
nullptr
)
{
seed_data
=
*
(
seed_tensor
->
data
<
int
>
());
}
else
{
seed_data
=
fix_seed
?
seed
:
0
;
}
auto
engine
=
paddle
::
framework
::
GetCPURandomEngine
(
seed_data
);
std
::
uniform_real_distribution
<
float
>
dist
(
0
,
1
);
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
dist
(
*
engine
)
<
dropout_prob
)
{
mask_data
[
i
]
=
0
;
y_data
[
i
]
=
0
;
}
else
{
mask_data
[
i
]
=
1
;
if
(
upscale_in_train
)
{
y_data
[
i
]
=
x_data
[
i
]
/
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
else
{
y_data
[
i
]
=
x_data
[
i
];
}
}
}
}
else
{
if
(
upscale_in_train
)
{
const
auto
*
X_data
=
x
.
data
<
T
>
();
auto
*
Y_data
=
y
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for
(
int
i
=
0
;
i
<
x
.
numel
();
i
++
)
{
Y_data
[
i
]
=
X_data
[
i
];
}
}
else
{
auto
X
=
EigenMatrix
<
T
>::
Reshape
(
x
,
1
);
auto
Y
=
EigenMatrix
<
T
>::
Reshape
(
*
y
,
1
);
auto
&
place
=
*
dev_ctx
.
eigen_device
();
Y
.
device
(
place
)
=
X
*
static_cast
<
T
>
(
1.0
f
-
dropout_prob
);
}
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
dropout
,
CPU
,
ALL_LAYOUT
,
phi
::
DropoutRawKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/dropout_grad_kernel.h
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutGradRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
mask
,
const
DenseTensor
&
out_grad
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
DenseTensor
*
x_grad
);
}
// namespace phi
paddle/phi/kernels/dropout_kernel.h
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
paddle
::
optional
<
const
DenseTensor
&>
seed_tensor
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
int
seed
,
bool
fix_seed
,
DenseTensor
*
out
,
DenseTensor
*
mask
);
}
// namespace phi
paddle/phi/kernels/gpu/dropout_grad_kernel.cu
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/phi/kernels/dropout_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutGradRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
mask
,
const
DenseTensor
&
out_grad
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
DenseTensor
*
x_grad
)
{
x_grad
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
size
=
x_grad
->
numel
();
paddle
::
operators
::
DropoutGradGPUKernelDriver
<
T
>
(
dev_ctx
,
mode
,
p
,
out_grad
,
mask
,
size
,
x_grad
,
is_test
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
dropout_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
DropoutGradRawKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/dropout_kernel.cu
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/phi/kernels/dropout_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
DropoutRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
paddle
::
optional
<
const
DenseTensor
&>
seed_tensor
,
float
p
,
bool
is_test
,
const
std
::
string
&
mode
,
int
seed
,
bool
fix_seed
,
DenseTensor
*
out
,
DenseTensor
*
mask
)
{
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
float
dropout_prob
=
p
;
bool
upscale_in_train
=
(
mode
==
"upscale_in_train"
);
mask
->
mutable_data
<
uint8_t
>
(
dev_ctx
.
GetPlace
());
paddle
::
operators
::
DropoutFwGPUKernelDriver
<
T
>
(
dev_ctx
,
is_test
,
mode
,
dropout_prob
,
upscale_in_train
,
fix_seed
,
seed
,
x
,
seed_tensor
.
get_ptr
(),
mask
,
out
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
dropout
,
GPU
,
ALL_LAYOUT
,
phi
::
DropoutRawKernel
,
float
,
double
,
phi
::
dtype
::
bfloat16
,
phi
::
dtype
::
float16
)
{}
paddle/phi/ops/compat/dropout_sig.cc
0 → 100644
浏览文件 @
99fc1b08
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
DropoutOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"dropout"
,
{
"X"
,
"Seed"
},
{
"dropout_prob"
,
"is_test"
,
"dropout_implementation"
,
"seed"
,
"fix_seed"
},
{
"Out"
,
"Mask"
});
}
KernelSignature
DropoutGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"dropout_grad"
,
{
"Mask"
,
GradVarName
(
"Out"
)},
{
"dropout_prob"
,
"is_test"
,
"dropout_implementation"
},
{
GradVarName
(
"X"
)});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
dropout
,
phi
::
DropoutOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
dropout_grad
,
phi
::
DropoutGradOpArgumentMapping
);
python/paddle/fluid/tests/unittests/test_dropout_op.py
浏览文件 @
99fc1b08
...
...
@@ -933,5 +933,65 @@ class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase):
self
.
check_static_result
(
place
=
place
)
class
TestDropoutBackward
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
123
)
self
.
places
=
[
fluid
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
self
.
places
.
append
(
fluid
.
CUDAPlace
(
0
))
def
cal_grad_upscale_train
(
self
,
mask
,
prob
):
return
mask
.
astype
(
"float32"
)
/
(
1
-
prob
)
def
cal_grad_downscale_in_infer
(
self
,
mask
):
return
mask
.
astype
(
"float32"
)
def
test_backward_downscale_in_infer
(
self
):
for
place
in
self
.
places
:
with
fluid
.
dygraph
.
guard
(
place
):
input
=
paddle
.
uniform
([
40
,
40
],
dtype
=
"float32"
)
input
.
stop_gradient
=
False
out
,
mask
=
core
.
ops
.
dropout
(
input
,
'dropout_prob'
,
0.5
)
out
.
backward
()
self
.
assertTrue
(
np
.
array_equal
(
input
.
gradient
(
),
self
.
cal_grad_downscale_in_infer
(
mask
.
numpy
())))
def
test_backward_upscale_train
(
self
):
for
place
in
self
.
places
:
with
fluid
.
dygraph
.
guard
(
place
):
prob
=
0.5
input
=
paddle
.
uniform
([
40
,
40
],
dtype
=
"float32"
)
input
.
stop_gradient
=
False
out
,
mask
=
core
.
ops
.
dropout
(
input
,
'dropout_prob'
,
prob
,
"dropout_implementation"
,
"upscale_in_train"
)
out
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
input
.
gradient
(
),
self
.
cal_grad_upscale_train
(
mask
.
numpy
(),
prob
)))
def
test_backward_upscale_train_2
(
self
):
for
place
in
self
.
places
:
with
fluid
.
dygraph
.
guard
(
place
):
prob
=
0.3
input
=
paddle
.
uniform
([
40
,
40
],
dtype
=
"float32"
)
input
.
stop_gradient
=
False
out
,
mask
=
core
.
ops
.
dropout
(
input
,
'dropout_prob'
,
prob
,
"dropout_implementation"
,
"upscale_in_train"
)
out
.
backward
()
self
.
assertTrue
(
np
.
allclose
(
input
.
gradient
(
),
self
.
cal_grad_upscale_train
(
mask
.
numpy
(),
prob
)))
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录