diff --git a/.gitignore b/.gitignore index 90138f996cf9cacc3c1cbff0cf2600eefca3f305..fa0c8882606b76ac71b43dcde7e1df6770c46c31 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ third_party/ build_* # clion workspace. cmake-build-* +model_test diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 19ef23cdfa90912ff6fbd050a685d10861d909d2..9a4e3caabae9a97d97b37e0ab5403ceb7fe28d03 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name' paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) +paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)) diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 07322e720f26213ea777be3cd22f2fead28507f0..3c28ef30922e6d6ba09b96282619eef15867631e 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dropout_op.h" +#include namespace paddle { namespace operators { @@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { "will be dropped.") .SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * dropout_prob" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string& type) { + PADDLE_ENFORCE( + type == "downgrade_in_infer" || type == "upscale_in_train", + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train"); + }); AddComment(R"DOC( Dropout Operator. @@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel, + ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, - ops::DropoutGradKernel); + ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 1dd66e0280c46c0624ff70e822cb6fa6f06b7aa9..e011f47e086183a4ef3a3373c17acd6c21b6cf7e 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/float16.h" @@ -26,7 +27,8 @@ namespace operators { template __global__ void RandomGenerator(const size_t n, const int seed, const float dropout_prob, const T* src, - T* mask_data, T* dst) { + T* mask_data, T* dst, + bool is_upscale_in_train) { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); @@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed, if (dist(rng) < dropout_prob) { mask = static_cast(0); } else { - mask = static_cast(1); + if (is_upscale_in_train) { + mask = static_cast(1.0f / (1.0f - dropout_prob)); + } else { + mask = static_cast(1); + } } dest = s * mask; mask_data[idx] = mask; @@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel { y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); @@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel { int grid = (x->numel() + threads - 1) / threads; RandomGenerator< T><<>>( - size, seed, dropout_prob, x_data, mask_data, y_data); + size, seed, dropout_prob, x_data, mask_data, y_data, + (dropout_implementation == "upscale_in_train")); } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); - Y.device(place) = X * static_cast(1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; @@ -99,6 +112,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL(dropout_grad, - ops::DropoutGradKernel); + ops::GPUDropoutKernel, + ops::GPUDropoutKernel); +REGISTER_OP_CUDA_KERNEL( + dropout_grad, ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 0628b4b826d2730a8e3fb4842e4ae550b8c00569..6c629b7b6d255828023ed25680675ca104a33e12 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel { auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); @@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel { engine.seed(seed); std::uniform_real_distribution dist(0, 1); + size_t size = framework::product(mask->dims()); for (size_t i = 0; i < size; ++i) { if (dist(engine) < dropout_prob) { mask_data[i] = 0; y_data[i] = 0; } else { - mask_data[i] = 1; - y_data[i] = x_data[i]; + if (dropout_implementation == "upscale_in_train") { + mask_data[i] = 1.0f / static_cast(1.0f - dropout_prob); + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + mask_data[i] = 1; + y_data[i] = x_data[i]; + } } } } else { @@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - Y.device(place) = X * (1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999621b10799b5163f326bc4b66a437e6..f6e241af0634650f4a32be6a4547617f8ec3ee60 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -76,6 +76,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 6a9fc6611a8f8eaa6749aefac0673ccabaebbcfe..bbd71db6062107f6ba40343c84d942b54b3958e6 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker, REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); REGISTER_OP_CPU_KERNEL( - transpose, ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, ops::Transpose2GradMaker); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad); REGISTER_OP_CPU_KERNEL( - transpose2, - ops::TransposeKernel); + transpose2, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/paddle/fluid/operators/transpose_op.cu.cc b/paddle/fluid/operators/transpose_op.cu.cc index c1b5a8b31be243fab3af06a18c8e51986c953700..b4025350fa9f3610bde43eee91cd059f3063813f 100644 --- a/paddle/fluid/operators/transpose_op.cu.cc +++ b/paddle/fluid/operators/transpose_op.cu.cc @@ -16,15 +16,18 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - transpose, - ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OP_CUDA_KERNEL( transpose2, - ops::TransposeKernel); + ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index d329db0d28998b30a81e3e6c21634cb6ff945b26..1738afe93e99f1de28bec2fb23be8e1a309d9288 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -272,7 +272,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ) square = grad * grad - local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64') + local_norm_var = layers.reduce_sum(input=square) context[self.group_name].append(local_norm_var) self.context = context @@ -282,7 +282,6 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): if group_scale_name not in self.context: group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sqrt(x=group_norm_var) - group_norm_var = layers.cast(group_norm_var, 'float32') clip_var = self.context[self.group_name + "_clip"] group_scale_var = layers.elementwise_div( x=clip_var, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cca618b9ad2fef9bf4870f0f94d17fbc529fb83c..98f4539feb65624e7cc1566c733b750d1324dee4 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -980,7 +980,12 @@ def cos_sim(X, Y): return out -def dropout(x, dropout_prob, is_test=False, seed=None, name=None): +def dropout(x, + dropout_prob, + is_test=False, + seed=None, + name=None, + dropout_implementation="downgrade_in_infer"): """ Computes dropout. @@ -1000,6 +1005,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): units will be dropped. DO NOT use a fixed seed in training. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. + dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train'] + 1. downgrade_in_infer(default), downgrade the outcome at inference + train: out = input * mask + inference: out = input * dropout_prob + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + 2. upscale_in_train, upscale the outcome at training time + train: out = input * mask / ( 1.0 - dropout_prob ) + inference: out = input + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + dropout op can be removed from the program. + the program will be efficient + + Returns: Variable: A tensor variable is the shape with `x`. @@ -1029,7 +1049,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): 'dropout_prob': dropout_prob, 'is_test': is_test, 'fix_seed': seed is not None, - 'seed': seed if seed is not None else 0 + 'seed': seed if seed is not None else 0, + 'dropout_implementation': dropout_implementation, }) return out diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 0296bc2af4e0b79478c34b4cceab32b5a8a50f2f..be3c5f3b9558ec522803ed9a5acedea75cda6ccc 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest): self.check_output() +class TestDropoutOp6(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 1.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': np.zeros((32, 64)).astype('float32'), + 'Mask': np.zeros((32, 64)).astype('float32') + } + + +class TestDropoutOp7(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('float32') + } + + +class TestDropoutOp8(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.35, + 'fix_seed': True, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp9(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.75, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + class TestFP16DropoutOp(OpTest): def setUp(self): self.op_type = "dropout"