diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index d306e20037a39d0170eb284dbd295f68d172e7c8..5848d9dad5e995eec51f54ae278d997e59195e1d 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -545,7 +545,7 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    dx.device(d) = static_cast<T>(0) / out;
+    dx.device(d) = static_cast<T>(0) * out;
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index ffcf8a5800ea11ae98bfa321b36af87952a516d9..9a545160a10d4396802e04de0535de053dca6af0 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -222,6 +222,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
             dev_ctx);
     void* cudnn_workspace_ptr =
         static_cast<void*>(cudnn_workspace.data<int8_t>());
+    VLOG(2) << "Cudnn workspace size fwd: "
+            << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
+            << " MB";
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
@@ -473,6 +476,9 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
                  {static_cast<int64_t>(workspace_size_in_bytes)}),
              dev_ctx);
       cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
+      VLOG(2) << "Cudnn workspace size bwd: "
+              << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
+              << " MB";
     }
 
     // ------------------- cudnn conv backward data ---------------------
diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 65c2ff6415c1d51fdc05d6014da589678761b676..273015f9763c2c7375aa0609436a2e8ab190b696 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -117,6 +117,14 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     ctx->ShareLoD(framework::GradVarName("Out"),
                   /*->*/ framework::GradVarName("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };
 
 class DropoutGradOpDescMaker : public framework::SingleGradOpDescMaker {
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index 7a6927d3e54b4ece8f17d7a1e7e431ba836edff9..e26eba68f15a9934a64081fddfffd49086f7faa8 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -22,10 +22,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename MaskType>
 __global__ void RandomGenerator(const size_t n, const int seed,
                                 const float dropout_prob, const T* src,
-                                T* mask_data, T* dst,
+                                MaskType* mask_data, T* dst,
                                 bool is_upscale_in_train) {
   thrust::minstd_rand rng;
   rng.seed(seed);
@@ -34,7 +34,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   int step_size = 0;
 
-  T mask;
+  MaskType mask;
   T dest;
   for (; idx < n; idx += blockDim.x * gridDim.x) {
     T s = src[idx];
@@ -45,15 +45,16 @@ __global__ void RandomGenerator(const size_t n, const int seed,
       rng.discard(step_size);
     }
     if (dist(rng) < dropout_prob) {
-      mask = static_cast<T>(0);
+      mask = 0;
+      dest = 0;
     } else {
+      mask = 1;
       if (is_upscale_in_train) {
-        mask = static_cast<T>(1.0f / (1.0f - dropout_prob));
+        dest = s / static_cast<T>(1.0f - dropout_prob);
       } else {
-        mask = static_cast<T>(1);
+        dest = s;
       }
     }
-    dest = s * mask;
     mask_data[idx] = mask;
     dst[idx] = dest;
   }
@@ -71,30 +72,40 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
 
-    auto dropout_implementation =
+    auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
+    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
+
     auto& place = *context.template device_context<Place>().eigen_device();
     if (!context.Attr<bool>("is_test")) {
+      int64_t x_numel = x->numel();
+      auto stream = context.cuda_device_context().stream();
+
       auto* mask = context.Output<Tensor>("Mask");
-      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
       size_t size = framework::product(mask->dims());
       auto* x_data = x->data<T>();
       auto* y_data = y->mutable_data<T>(context.GetPlace());
+      if (dropout_prob == 1.0f) {
+        PADDLE_ENFORCE(cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
+        PADDLE_ENFORCE(cudaMemsetAsync(mask_data, 0,
+                                       x_numel * sizeof(*mask_data), stream));
+        return;
+      }
 
       std::random_device rnd;
       int seed =
          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
 
      int threads = 512;
-      int grid = (x->numel() + threads - 1) / threads;
-      RandomGenerator<
-          T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
+      int grid = (x_numel + threads - 1) / threads;
+      RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
          size, seed, dropout_prob, x_data, mask_data, y_data,
-          (dropout_implementation == "upscale_in_train"));
+          upscale_in_train);
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
-      if (dropout_implementation == "upscale_in_train") {
+      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 6c629b7b6d255828023ed25680675ca104a33e12..09c4899c7376700fbeb3ca9735e9456138b9a08e 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <cstring>
 #include <random>
 #include <string>
 
@@ -37,11 +38,20 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
 
-    auto dropout_implementation =
+    auto& dropout_implementation =
         context.Attr<std::string>("dropout_implementation");
+    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
     if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
-      auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
+      size_t size = framework::product(mask->dims());
+
+      // Special case when dropout_prob is 1.0
+      if (dropout_prob == 1.0f) {
+        std::memset(y_data, 0, size * sizeof(*y_data));        // NOLINT
+        std::memset(mask_data, 0, size * sizeof(*mask_data));  // NOLINT
+        return;
+      }
 
       // NOTE: fixed seed should only be used in unittest or for debug.
       // Guarantee to use random seed in training.
@@ -53,17 +63,15 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       std::uniform_real_distribution<float> dist(0, 1);
 
-      size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
           mask_data[i] = 0;
           y_data[i] = 0;
         } else {
-          if (dropout_implementation == "upscale_in_train") {
-            mask_data[i] = 1.0f / static_cast<T>(1.0f - dropout_prob);
+          mask_data[i] = 1;
+          if (upscale_in_train) {
             y_data[i] = x_data[i] / static_cast<T>(1.0f - dropout_prob);
           } else {
-            mask_data[i] = 1;
             y_data[i] = x_data[i];
           }
         }
       }
@@ -73,7 +81,7 @@
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
       auto& place =
           *context.template device_context<DeviceContext>().eigen_device();
-      if (dropout_implementation == "upscale_in_train") {
+      if (upscale_in_train) {
         Y.device(place) = X;
       } else {
         Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
@@ -94,13 +102,26 @@ class DropoutGradKernel : public framework::OpKernel<T> {
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
 
-    auto M = EigenMatrix<T>::Reshape(*mask, 1);
+    auto M = EigenMatrix<uint8_t>::Reshape(*mask, 1);
     auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
     auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
 
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    dX.device(place) = dY * M;
+
+    auto& dropout_implementation =
+        context.Attr<std::string>("dropout_implementation");
+    if (dropout_implementation == "upscale_in_train") {
+      float dropout_prob = context.Attr<float>("dropout_prob");
+      if (dropout_prob == 1.0f) {
+        dX.device(place) = static_cast<T>(0) * dY;
+      } else {
+        dX.device(place) =
+            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+      }
+    } else {
+      dX.device(place) = dY * M.cast<T>();
+    }
   }
 };
 
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index cd3b7354ed9eb2a0a0542ff0a18f1c9922e4cbe2..37997159b4ce13fb1ce14194088f904c906842a7 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -1390,7 +1390,7 @@ def dropout(x,
     helper = LayerHelper('dropout', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
     mask = helper.create_variable_for_type_inference(
-        dtype=x.dtype, stop_gradient=True)
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
 
     if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
         seed = helper.main_program.random_seed
diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py
index be3c5f3b9558ec522803ed9a5acedea75cda6ccc..59918a7bb21c42359f7d6c4f6109ca4b1cdc4449 100644
--- a/python/paddle/fluid/tests/unittests/test_dropout_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -27,7 +27,7 @@ class TestDropoutOp(OpTest):
         self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64)).astype('float32')
+            'Mask': np.ones((32, 64)).astype('uint8')
         }
 
     def test_check_output(self):
@@ -44,7 +44,7 @@ class TestDropoutOp2(TestDropoutOp):
         self.attrs = {'dropout_prob': 1.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': np.zeros((32, 64)).astype('float32'),
-            'Mask': np.zeros((32, 64)).astype('float32')
+            'Mask': np.zeros((32, 64)).astype('uint8')
         }
 
 
@@ -55,7 +55,7 @@ class TestDropoutOp3(TestDropoutOp):
         self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64, 2)).astype('float32')
+            'Mask': np.ones((32, 64, 2)).astype('uint8')
         }
 
 
@@ -97,7 +97,7 @@ class TestDropoutOp6(TestDropoutOp):
         }
         self.outputs = {
             'Out': np.zeros((32, 64)).astype('float32'),
-            'Mask': np.zeros((32, 64)).astype('float32')
+            'Mask': np.zeros((32, 64)).astype('uint8')
         }
 
 
@@ -113,7 +113,7 @@ class TestDropoutOp7(TestDropoutOp):
         }
         self.outputs = {
             'Out': self.inputs['X'],
-            'Mask': np.ones((32, 64, 2)).astype('float32')
+            'Mask': np.ones((32, 64, 2)).astype('uint8')
         }
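
For reference (not part of the patch), here is a minimal NumPy sketch of the dropout semantics the updated kernels and tests above encode: a 0/1 `uint8` mask, fully zeroed `Out`/`Mask` when `dropout_prob == 1.0`, and `1 / (1 - dropout_prob)` scaling of kept values under `upscale_in_train`. The helper name `ref_dropout` is invented for illustration only.

```python
import numpy as np


def ref_dropout(x, dropout_prob, upscale_in_train=True, seed=0):
    """Forward-pass sketch mirroring the patched CPU/GPU dropout kernels."""
    rng = np.random.RandomState(seed)
    if dropout_prob == 1.0:
        # Special case handled via memset in the kernels: everything is dropped.
        return np.zeros_like(x), np.zeros(x.shape, dtype=np.uint8)
    # uint8 keep-mask: 1 = kept, 0 = dropped (matches the new Mask dtype).
    mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(np.uint8)
    if upscale_in_train:
        # Scale kept activations so the training-time expectation matches inference.
        out = x * mask / (1.0 - dropout_prob)
    else:
        out = x * mask
    return out.astype(x.dtype), mask


out, mask = ref_dropout(np.ones((32, 64), dtype=np.float32), dropout_prob=0.5)
```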