From a6e6bc45d63ab35c0fdb7450f1a6c9188a86c5be Mon Sep 17 00:00:00 2001 From: phlrain Date: Wed, 24 Oct 2018 09:03:50 +0000 Subject: [PATCH] modify dropout att; test=develop --- paddle/fluid/operators/dropout_op.cc | 33 ++++++++++++++----- paddle/fluid/operators/dropout_op.cu | 12 ++++--- paddle/fluid/operators/dropout_op.h | 8 +++-- python/paddle/fluid/layers/nn.py | 23 ++++++++----- .../fluid/tests/unittests/test_dropout_op.py | 8 ++--- 5 files changed, 55 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index b5023f391c8..3c28ef30922 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dropout_op.h" +#include namespace paddle { namespace operators { @@ -57,15 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { "will be dropped.") .SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); - AddAttr("dropout_implementation", - "When it's True, In the training, after set some value" - "to 0 (probability is dropout_prob)," - "all the value will divide (1-dropout_prob)" - "By using this way, will do nothing in the inference program" - "The dropout op can be removed in the inference program." - "The inference program will be more efficient" - "When it's False, same as original") - .SetDefault(false); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * dropout_prob" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string& type) { + PADDLE_ENFORCE( + type == "downgrade_in_infer" || type == "upscale_in_train", + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train"); + }); AddComment(R"DOC( Dropout Operator. diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index a3d264ac136..e011f47e086 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/float16.h" @@ -27,7 +28,7 @@ template __global__ void RandomGenerator(const size_t n, const int seed, const float dropout_prob, const T* src, T* mask_data, T* dst, - bool dropout_implementation) { + bool is_upscale_in_train) { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); @@ -48,7 +49,7 @@ __global__ void RandomGenerator(const size_t n, const int seed, if (dist(rng) < dropout_prob) { mask = static_cast(0); } else { - if (dropout_implementation) { + if (is_upscale_in_train) { mask = static_cast(1.0f / (1.0f - dropout_prob)); } else { mask = static_cast(1); @@ -72,7 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel { y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); - auto dropout_implementation = context.Attr("dropout_implementation"); + auto dropout_implementation = + context.Attr("dropout_implementation"); auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); @@ -90,11 +92,11 @@ class GPUDropoutKernel : public framework::OpKernel { RandomGenerator< T><<>>( size, seed, dropout_prob, x_data, mask_data, y_data, - dropout_implementation); + (dropout_implementation == "upscale_in_train")); } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); - if (dropout_implementation) { + if (dropout_implementation == "upscale_in_train") { Y.device(place) = X; } else { Y.device(place) = X * static_cast(1.0f - dropout_prob); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index bc86aeb7f09..6c629b7b6d2 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -36,7 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel { auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); - auto dropout_implementation = context.Attr("dropout_implementation"); + auto dropout_implementation = + context.Attr("dropout_implementation"); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); @@ -57,7 +59,7 @@ class CPUDropoutKernel : public framework::OpKernel { mask_data[i] = 0; y_data[i] = 0; } else { - if (dropout_implementation) { + if (dropout_implementation == "upscale_in_train") { mask_data[i] = 1.0f / static_cast(1.0f - dropout_prob); y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); } else { @@ -71,7 +73,7 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - if (dropout_implementation) { + if (dropout_implementation == "upscale_in_train") { Y.device(place) = X; } else { Y.device(place) = X * static_cast(1.0f - dropout_prob); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 83446e4bd16..98f4539feb6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -985,7 +985,7 @@ def dropout(x, is_test=False, seed=None, name=None, - dropout_implementation=False): + dropout_implementation="downgrade_in_infer"): """ Computes dropout. @@ -1005,13 +1005,20 @@ def dropout(x, units will be dropped. DO NOT use a fixed seed in training. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. - dropout_implementation(bool): A Flag indicating whether divide (1-dropout_prob). - When it's True, all the units will divide (1-dropout_prob) - after set some units to zero in the train program. - And do nothing in the inference program. - The dropout op can be removed in the inference program. - The inference program will be more efficient - When it's False, same as original + dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train'] + 1. downgrade_in_infer(default), downgrade the outcome at inference + train: out = input * mask + inference: out = input * dropout_prob + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + 2. upscale_in_train, upscale the outcome at training time + train: out = input * mask / ( 1.0 - dropout_prob ) + inference: out = input + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + dropout op can be removed from the program. + the program will be efficient + Returns: diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index ecfacb3277b..be3c5f3b955 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -93,7 +93,7 @@ class TestDropoutOp6(TestDropoutOp): 'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False, - 'div_prob_in_train': True + 'dropout_implementation': 'upscale_in_train' } self.outputs = { 'Out': np.zeros((32, 64)).astype('float32'), @@ -109,7 +109,7 @@ class TestDropoutOp7(TestDropoutOp): 'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False, - 'div_prob_in_train': True + 'dropout_implementation': 'upscale_in_train' } self.outputs = { 'Out': self.inputs['X'], @@ -125,7 +125,7 @@ class TestDropoutOp8(OpTest): 'dropout_prob': 0.35, 'fix_seed': True, 'is_test': True, - 'div_prob_in_train': True + 'dropout_implementation': 'upscale_in_train' } self.outputs = {'Out': self.inputs['X']} @@ -140,7 +140,7 @@ class TestDropoutOp9(OpTest): self.attrs = { 'dropout_prob': 0.75, 'is_test': True, - 'div_prob_in_train': True + 'dropout_implementation': 'upscale_in_train' } self.outputs = {'Out': self.inputs['X']} -- GitLab