未验证 提交 379d933a 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #14036 from phlrain/add_dropout_att_new

Add dropout att new 1.1 merge
......@@ -28,3 +28,4 @@ third_party/
build_*
# clion workspace.
cmake-build-*
model_test
......@@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name'
paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None))
paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer'))
paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None))
paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None))
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/dropout_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
"will be dropped.")
.SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_prob)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * dropout_prob"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_prob )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string& type) {
PADDLE_ENFORCE(
type == "downgrade_in_infer" || type == "upscale_in_train",
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train");
});
AddComment(R"DOC(
Dropout Operator.
......@@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
......@@ -26,7 +27,8 @@ namespace operators {
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
T* mask_data, T* dst,
bool is_upscale_in_train) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<float> dist(0, 1);
......@@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed,
if (dist(rng) < dropout_prob) {
mask = static_cast<T>(0);
} else {
mask = static_cast<T>(1);
if (is_upscale_in_train) {
mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
} else {
mask = static_cast<T>(1);
}
}
dest = s * mask;
mask_data[idx] = mask;
......@@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
......@@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
size, seed, dropout_prob, x_data, mask_data, y_data,
(dropout_implementation == "upscale_in_train"));
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
}
};
......@@ -99,6 +112,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
ops::DropoutGradKernel<plat::CUDADeviceContext, double>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <random>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto* y_data = y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
......@@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
engine.seed(seed);
std::uniform_real_distribution<float> dist(0, 1);
size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
mask_data[i] = 1;
y_data[i] = x_data[i];
if (dropout_implementation == "upscale_in_train") {
mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
} else {
mask_data[i] = 1;
y_data[i] = x_data[i];
}
}
}
} else {
......@@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
Y.device(place) = X * (1.0f - dropout_prob);
if (dropout_implementation == "upscale_in_train") {
Y.device(place) = X;
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
}
};
......
......@@ -76,6 +76,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<double>,
ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>);
ops::SoftmaxGradCUDNNKernel<float>,
ops::SoftmaxGradCUDNNKernel<double>);
......@@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker,
REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
ops::Transpose2GradMaker);
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
REGISTER_OP_CPU_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>);
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -16,15 +16,18 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
transpose,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
transpose, ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
transpose_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
transpose2,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>);
ops::TransposeKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
transpose2_grad,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TransposeGradKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
)
square = grad * grad
local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
local_norm_var = layers.reduce_sum(input=square)
context[self.group_name].append(local_norm_var)
self.context = context
......@@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = layers.sqrt(x=group_norm_var)
group_norm_var = layers.cast(group_norm_var, 'float32')
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div(
x=clip_var,
......
......@@ -980,7 +980,12 @@ def cos_sim(X, Y):
return out
def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
def dropout(x,
dropout_prob,
is_test=False,
seed=None,
name=None,
dropout_implementation="downgrade_in_infer"):
"""
Computes dropout.
......@@ -1000,6 +1005,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
units will be dropped. DO NOT use a fixed seed in training.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train']
1. downgrade_in_infer(default), downgrade the outcome at inference
train: out = input * mask
inference: out = input * dropout_prob
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
2. upscale_in_train, upscale the outcome at training time
train: out = input * mask / ( 1.0 - dropout_prob )
inference: out = input
(make is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob)
dropout op can be removed from the program.
the program will be efficient
Returns:
Variable: A tensor variable is the shape with `x`.
......@@ -1029,7 +1049,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
'dropout_prob': dropout_prob,
'is_test': is_test,
'fix_seed': seed is not None,
'seed': seed if seed is not None else 0
'seed': seed if seed is not None else 0,
'dropout_implementation': dropout_implementation,
})
return out
......
......@@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest):
self.check_output()
class TestDropoutOp6(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {
'dropout_prob': 1.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('float32')
}
class TestDropoutOp7(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {
'dropout_prob': 0.0,
'fix_seed': True,
'is_test': False,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {
'Out': self.inputs['X'],
'Mask': np.ones((32, 64, 2)).astype('float32')
}
class TestDropoutOp8(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {
'dropout_prob': 0.35,
'fix_seed': True,
'is_test': True,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {'Out': self.inputs['X']}
def test_check_output(self):
self.check_output()
class TestDropoutOp9(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {
'dropout_prob': 0.75,
'is_test': True,
'dropout_implementation': 'upscale_in_train'
}
self.outputs = {'Out': self.inputs['X']}
def test_check_output(self):
self.check_output()
class TestFP16DropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册