From 2e331c6593a13a090bcce2c16992bbf0baacf980 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 19 Apr 2018 17:38:36 +0800 Subject: [PATCH] accelerate dropout (#9902) * accelerate dropout * accelerate dropout * "fix the dropout test" * "rerun ci" * "fix ci" * "rerun ci" * "fix ci" * "fix" * "stage" * disable --- paddle/fluid/operators/dropout_op.cu | 36 +++++++++++++---------- paddle/fluid/operators/dropout_op.h | 12 ++++---- paddle/fluid/operators/dropout_op_test.cc | 29 ++++++++++-------- 3 files changed, 43 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 184c095e48..490dce19b6 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -24,21 +24,11 @@ namespace paddle { namespace operators { template -__global__ void RandomGenerator(const size_t n, const int seed, - const float dropout_prob, const T* src, - T* mask_data, T* dst) { - thrust::minstd_rand rng; - rng.seed(seed); - thrust::uniform_real_distribution dist(0, 1); - +__global__ void RandomGenerator(const size_t n, const T* src, + const T* cpu_mask_data, T* mask_data, T* dst) { int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < n; idx += blockDim.x * gridDim.x) { - rng.discard(idx); - if (dist(rng) < dropout_prob) { - mask_data[idx] = static_cast(0); - } else { - mask_data[idx] = static_cast(1); - } + mask_data[idx] = cpu_mask_data[idx]; dst[idx] = mask_data[idx] * src[idx]; } } @@ -66,15 +56,27 @@ class GPUDropoutKernel : public framework::OpKernel { std::random_device rnd; int seed = context.Attr("fix_seed") ? context.Attr("seed") : rnd(); + std::minstd_rand engine; + engine.seed(seed); + std::uniform_real_distribution dist(0, 1); + framework::Vector cpu_mask(size); + for (size_t i = 0; i < size; ++i) { + if (dist(engine) < dropout_prob) { + cpu_mask[i] = static_cast(0); + } else { + cpu_mask[i] = static_cast(1); + } + } int threads = 512; int grid = (x->numel() + threads - 1) / threads; RandomGenerator< T><<>>( - size, seed, dropout_prob, x_data, mask_data, y_data); + size, x_data, cpu_mask.CUDAData(context.GetPlace()), mask_data, + y_data); } else { - auto X = EigenMatrix::Reshape(*x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); + auto X = EigenVector::Flatten(*x); + auto Y = EigenVector::Flatten(*y); Y.device(place) = X * static_cast(1.0f - dropout_prob); } } @@ -87,6 +89,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, + ops::GPUDropoutKernel, ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL(dropout_grad, + ops::DropoutGradKernel, ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 0628b4b826..41ca242d8f 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -24,7 +24,7 @@ namespace operators { using Tensor = framework::Tensor; template -using EigenMatrix = framework::EigenMatrix; +using EigenVector = framework::EigenVector; template class CPUDropoutKernel : public framework::OpKernel { @@ -60,8 +60,8 @@ class CPUDropoutKernel : public framework::OpKernel { } } } else { - auto X = EigenMatrix::Reshape(*x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); + auto X = EigenVector::Flatten(*x); + auto Y = EigenVector::Flatten(*y); auto& place = *context.template device_context().eigen_device(); Y.device(place) = X * (1.0f - dropout_prob); @@ -81,9 +81,9 @@ class DropoutGradKernel : public framework::OpKernel { auto* mask = context.Input("Mask"); grad_x->mutable_data(context.GetPlace()); - auto M = EigenMatrix::Reshape(*mask, 1); - auto dX = EigenMatrix::Reshape(*grad_x, 1); - auto dY = EigenMatrix::Reshape(*grad_y, 1); + auto M = EigenVector::Flatten(*mask); + auto dX = EigenVector::Flatten(*grad_x); + auto dY = EigenVector::Flatten(*grad_y); auto& place = *context.template device_context().eigen_device(); diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc index 424d273c34..47ea847674 100644 --- a/paddle/fluid/operators/dropout_op_test.cc +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include #include // NOLINT @@ -32,14 +33,16 @@ namespace m = paddle::operators::math; USE_OP(dropout); +static paddle::framework::DDim dims = {10, 10}; + void Compare(f::Scope* scope, const p::DeviceContext& ctx) { // init auto var = scope->Var("X"); auto tensor = var->GetMutable(); - tensor->Resize({10, 10}); + tensor->Resize(dims); std::vector init; - for (int64_t i = 0; i < 10 * 10; ++i) { + for (int64_t i = 0; i < f::product(dims); ++i) { init.push_back(1.0); } @@ -48,18 +51,19 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { auto place = ctx.GetPlace(); auto out_var = scope->Var("Out"); auto out_tensor = out_var->GetMutable(); - out_tensor->Resize({10, 10}); + out_tensor->Resize(dims); out_tensor->mutable_data(place); // allocate auto mask_var = scope->Var("Mask"); auto mask_tensor = mask_var->GetMutable(); - mask_tensor->Resize({10, 10}); + mask_tensor->Resize(dims); mask_tensor->mutable_data(place); // allocate // run f::AttributeMap attrs; float dropout_prob = 0.5; - attrs.insert({"fix_seed", 1}); + attrs.insert({"is_test", false}); + attrs.insert({"fix_seed", true}); attrs.insert({"seed", 3}); attrs.insert({"dropout_prob", dropout_prob}); auto dropout_op = f::OpRegistry::CreateOp( @@ -69,6 +73,7 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { std::vector out_vec; TensorToVector(*out_tensor, ctx, &out_vec); + ctx.Wait(); std::vector std_out = { 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, @@ -83,22 +88,22 @@ void Compare(f::Scope* scope, const p::DeviceContext& ctx) { } } -// TODO(wyi): Due to -// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily -// disable this test to remove the prevention of the merge of -// unrelated PRs. -/* TEST(Dropout, CPUDense) { f::Scope scope; p::CPUPlace place; p::CPUDeviceContext ctx(place); - Compare(scope, ctx); + Compare(&scope, ctx); } +// TODO(wyi, dzhwinter): Due to +// https://github.com/PaddlePaddle/Paddle/issues/9507, I temporarily +// disable this test to remove the prevention of the merge of +// unrelated PRs. +/* TEST(Dropout, GPUDense) { f::Scope scope; p::CUDAPlace place; p::CUDADeviceContext ctx(place); - Compare(scope, ctx); + Compare(&scope, ctx); } */ -- GitLab