From 6b4a51bae3e8bdb1266573d67adde9fb55cf86b6 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Mon, 13 Apr 2020 10:52:13 +0800
Subject: [PATCH] add cuda kernel for seed, test=develop (#23749)

* add cuda kernel for seed, test=develop
---
 paddle/fluid/operators/dropout_op.cu          |  67 ++++++++---
 paddle/fluid/operators/seed_op.cc             |   2 +-
 paddle/fluid/operators/seed_op.cu             |  47 ++++++++
 .../fluid/tests/unittests/test_optimizer.py   | 106 ++++++++++++++++++
 4 files changed, 207 insertions(+), 15 deletions(-)
 create mode 100644 paddle/fluid/operators/seed_op.cu

diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index d9b5572d95c..4d5e4c4f600 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -18,9 +18,11 @@ limitations under the License. */
 #include
 #include
 #include
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/dynload/curand.h"
 #include "paddle/fluid/platform/float16.h"
+
 namespace paddle {
 namespace operators {
 
@@ -59,6 +61,41 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   }
 }
 
+template <typename T, typename MaskType>
+__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
+                                        const float dropout_prob, const T* src,
+                                        MaskType* mask_data, T* dst,
+                                        bool is_upscale_in_train) {
+  curandStatePhilox4_32_10_t state;
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int step_size = 0;
+
+  MaskType mask;
+  T dest;
+  for (; idx < n; idx += blockDim.x * gridDim.x) {
+    T s = src[idx];
+    if (step_size == 0) {
+      curand_init(seed[0], idx, idx, &state);
+      step_size = blockDim.x * gridDim.x;
+    } else {
+      curand_init(seed[0], idx, step_size, &state);
+    }
+    if (curand_uniform(&state) < dropout_prob) {
+      mask = 0;
+      dest = 0;
+    } else {
+      mask = 1;
+      if (is_upscale_in_train) {
+        dest = s / static_cast<T>(1.0f - dropout_prob);
+      } else {
+        dest = s;
+      }
+    }
+    mask_data[idx] = mask;
+    dst[idx] = dest;
+  }
+}
+
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -86,20 +123,6 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
     size_t size = framework::product(mask->dims());
     auto* x_data = x->data<T>();
-    int seed_data;
-    std::random_device rnd;
-    if (seed) {
-      if (platform::is_gpu_place(seed->place())) {
-        framework::Tensor temp;
-        TensorCopySync(*seed, platform::CPUPlace(), &temp);
-        seed_data = *(temp.data<int>());
-      } else {
-        seed_data = *(seed->data<int>());
-      }
-    } else {
-      seed_data =
-          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
-    }
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     if (dropout_prob == 1.0f) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -111,6 +134,22 @@
     int threads = 512;
     int grid = (x_numel + threads - 1) / threads;
 
+    if (seed && platform::is_gpu_place(seed->place())) {
+      auto seed_gpu_data = seed->data<int>();
+      RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
+          size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
+          upscale_in_train);
+      return;
+    }
+    int seed_data;
+    std::random_device rnd;
+    if (seed) {
+      seed_data = *(seed->data<int>());
+    } else {
+      seed_data =
+          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
+    }
+
     RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
         size, seed_data, dropout_prob, x_data, mask_data, y_data,
         upscale_in_train);
diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc
index 86c551f4c74..2f3e4c9ba88 100644
--- a/paddle/fluid/operators/seed_op.cc
+++ b/paddle/fluid/operators/seed_op.cc
@@ -30,7 +30,7 @@ class SeedOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(framework::proto::VarType::INT32,
-                                   platform::CPUPlace());
+                                   ctx.device_context());
   }
 };
 
diff --git a/paddle/fluid/operators/seed_op.cu b/paddle/fluid/operators/seed_op.cu
new file mode 100644
index 00000000000..aa2c329c4d0
--- /dev/null
+++ b/paddle/fluid/operators/seed_op.cu
@@ -0,0 +1,47 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <random>
+#include "paddle/fluid/operators/seed_op.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Platform, typename T>
+class GPUSeedKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* out_data = out->mutable_data<T>(context.GetPlace());
+    int user_seed = context.Attr<int>("seed");
+    std::random_device rnd;
+    int seed;
+    if (user_seed != 0) {
+      seed = user_seed;
+    } else {
+      seed = rnd();
+    }
+    auto target_gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
+    auto stream = context.cuda_device_context().stream();
+    memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed,
+                 sizeof(int), stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(
+    seed,
+    paddle::operators::GPUSeedKernel<paddle::platform::CUDADeviceContext, int>);
diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index fefcac6ede7..7894fc01887 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -16,10 +16,14 @@ from __future__ import print_function
 
 import unittest
 
+import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import paddle.fluid.optimizer as optimizer
+import paddle.fluid.core as core
 import paddle.compat as cpt
+import numpy as np
 from paddle.fluid.backward import append_backward
+from paddle.fluid.framework import Program, program_guard
 
 
 class TestOptimizer(unittest.TestCase):
@@ -841,6 +845,108 @@ class TestRecomputeOptimizer(unittest.TestCase):
             "sgd", "sgd", "sgd"
         ])
 
+    def test_dropout_with_seed(self):
+        """
+        when we recompute a dropout op, make sure that the recomputed one
+        is the same as the original var.
+        """
+
+        def gen_data():
+            return {
+                "x": np.random.random(size=(100, 3)).astype('float32'),
+                "y": np.random.randint(
+                    2, size=(100, 1)).astype('int64')
+            }
+
+        def mlp(input_x, input_y):
+            drop_res = fluid.layers.dropout(
+                input_x, dropout_prob=0.5, name="dropout_with_seed_cpu")
+            prediction = fluid.layers.fc(input=[drop_res],
+                                         size=2,
+                                         act='softmax')
+            cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
+            sum_cost = fluid.layers.reduce_mean(cost)
+            return drop_res, prediction, sum_cost
+
+        main_program = Program()
+        startup_program = Program()
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            with program_guard(main_program, startup_program):
+                input_x = fluid.layers.data(
+                    name="x", shape=[3], dtype='float32')
+                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
+                drop_res, prediction, cost = mlp(input_x, input_y)
+                sgd = fluid.optimizer.Adam(learning_rate=0.01)
+                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
+                sgd._set_checkpoints([prediction])
+                sgd.minimize(cost)
+
+                place = fluid.CPUPlace()
+                exe = fluid.Executor(place)
+                exe.run(fluid.default_startup_program())
+                feed_data = gen_data()
+                drop_vec = exe.run(feed=feed_data,
+                                   program=fluid.default_main_program(),
+                                   fetch_list=[
+                                       "dropout_with_seed_cpu.tmp_1",
+                                       "dropout_with_seed_cpu.tmp_1.subprog_0"
+                                   ])
+                self.assertEqual(drop_vec[0].tolist(), drop_vec[1].tolist())
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestRecomputeOptimizerCUDA(unittest.TestCase):
+    def test_dropout_with_seed(self):
+        """
+        when we recompute a dropout op, make sure that the recomputed one
+        is the same as the original var.
+        """
+
+        def gen_data():
+            return {
+                "x": np.random.random(size=(100, 3)).astype('float32'),
+                "y": np.random.randint(
+                    2, size=(100, 1)).astype('int64')
+            }
+
+        def mlp(input_x, input_y):
+            drop_res = fluid.layers.dropout(
+                input_x, dropout_prob=0.5, name="dropout_with_seed_gpu")
+            prediction = fluid.layers.fc(input=[drop_res],
+                                         size=2,
+                                         act='softmax')
+            cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
+            sum_cost = fluid.layers.reduce_mean(cost)
+            return drop_res, prediction, sum_cost
+
+        main_program = Program()
+        startup_program = Program()
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            with program_guard(main_program, startup_program):
+                input_x = fluid.layers.data(
+                    name="x", shape=[3], dtype='float32')
+                input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
+                drop_res, prediction, cost = mlp(input_x, input_y)
+                sgd = fluid.optimizer.Adam(learning_rate=0.01)
+                sgd = fluid.optimizer.RecomputeOptimizer(sgd)
+                sgd._set_checkpoints([prediction])
+                sgd.minimize(cost)
+
+                place = fluid.CUDAPlace(0)
+                exe = fluid.Executor(place)
+                exe.run(fluid.default_startup_program())
+                feed_data = gen_data()
+                drop_vec = exe.run(feed=feed_data,
+                                   program=fluid.default_main_program(),
+                                   fetch_list=[
+                                       "dropout_with_seed_gpu.tmp_1",
+                                       "dropout_with_seed_gpu.tmp_1.subprog_0"
+                                   ])
+                self.assertEqual(drop_vec[0].tolist(), drop_vec[1].tolist())
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab