未验证 提交 6b4a51ba 编写于 作者: M mapingshuo 提交者: GitHub

add cuda kernel for seed, test=develop (#23749)

* add cuda kernel for seed, test=develop
上级 a9fe09f8
......@@ -18,9 +18,11 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/dynload/curand.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -59,6 +61,41 @@ __global__ void RandomGenerator(const size_t n, const int seed,
}
}
template <typename T, typename MaskType>
__global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
const float dropout_prob, const T* src,
MaskType* mask_data, T* dst,
bool is_upscale_in_train) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;
MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed[0], idx, idx, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed[0], idx, step_size, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
}
}
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
......@@ -86,20 +123,6 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
size_t size = framework::product(mask->dims());
auto* x_data = x->data<T>();
int seed_data;
std::random_device rnd;
if (seed) {
if (platform::is_gpu_place(seed->place())) {
framework::Tensor temp;
TensorCopySync(*seed, platform::CPUPlace(), &temp);
seed_data = *(temp.data<int>());
} else {
seed_data = *(seed->data<int>());
}
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}
auto* y_data = y->mutable_data<T>(context.GetPlace());
if (dropout_prob == 1.0f) {
PADDLE_ENFORCE_CUDA_SUCCESS(
......@@ -111,6 +134,22 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x_numel + threads - 1) / threads;
if (seed && platform::is_gpu_place(seed->place())) {
auto seed_gpu_data = seed->data<int>();
RandomGeneratorWithSeed<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_gpu_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
return;
}
int seed_data;
std::random_device rnd;
if (seed) {
seed_data = *(seed->data<int>());
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}
RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
......
......@@ -30,7 +30,7 @@ class SeedOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::INT32,
platform::CPUPlace());
ctx.device_context());
}
};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include "paddle/fluid/operators/seed_op.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class GPUSeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<Tensor>("Out");
auto* out_data = out->mutable_data<T>(context.GetPlace());
int user_seed = context.Attr<int>("seed");
std::random_device rnd;
int seed;
if (user_seed != 0) {
seed = user_seed;
} else {
seed = rnd();
}
auto target_gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
auto stream = context.cuda_device_context().stream();
memory::Copy(target_gpu_place, out_data, platform::CPUPlace(), &seed,
sizeof(int), stream);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
seed,
paddle::operators::GPUSeedKernel<paddle::platform::CUDADeviceContext, int>);
......@@ -16,10 +16,14 @@ from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.fluid.core as core
import paddle.compat as cpt
import numpy as np
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
class TestOptimizer(unittest.TestCase):
......@@ -841,6 +845,108 @@ class TestRecomputeOptimizer(unittest.TestCase):
"sgd", "sgd", "sgd"
])
def test_dropout_with_seed(self):
"""
when we recompute a dropout op, make sure that the recomputed one
is the same as the original var.
"""
def gen_data():
return {
"x": np.random.random(size=(100, 3)).astype('float32'),
"y": np.random.randint(
2, size=(100, 1)).astype('int64')
}
def mlp(input_x, input_y):
drop_res = fluid.layers.dropout(
input_x, dropout_prob=0.5, name="dropout_with_seed_cpu")
prediction = fluid.layers.fc(input=[drop_res],
size=2,
act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return drop_res, prediction, sum_cost
main_program = Program()
startup_program = Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with program_guard(main_program, startup_program):
input_x = fluid.layers.data(
name="x", shape=[3], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
drop_res, prediction, cost = mlp(input_x, input_y)
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([prediction])
sgd.minimize(cost)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feed_data = gen_data()
drop_vec = exe.run(feed=feed_data,
program=fluid.default_main_program(),
fetch_list=[
"dropout_with_seed_cpu.tmp_1",
"dropout_with_seed_cpu.tmp_1.subprog_0"
])
self.assertEqual(drop_vec[0].tolist(), drop_vec[1].tolist())
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestRecomputeOptimizerCUDA(unittest.TestCase):
def test_dropout_with_seed(self):
"""
when we recompute a dropout op, make sure that the recomputed one
is the same as the original var.
"""
def gen_data():
return {
"x": np.random.random(size=(100, 3)).astype('float32'),
"y": np.random.randint(
2, size=(100, 1)).astype('int64')
}
def mlp(input_x, input_y):
drop_res = fluid.layers.dropout(
input_x, dropout_prob=0.5, name="dropout_with_seed_gpu")
prediction = fluid.layers.fc(input=[drop_res],
size=2,
act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
sum_cost = fluid.layers.reduce_mean(cost)
return drop_res, prediction, sum_cost
main_program = Program()
startup_program = Program()
scope = fluid.Scope()
with fluid.scope_guard(scope):
with program_guard(main_program, startup_program):
input_x = fluid.layers.data(
name="x", shape=[3], dtype='float32')
input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')
drop_res, prediction, cost = mlp(input_x, input_y)
sgd = fluid.optimizer.Adam(learning_rate=0.01)
sgd = fluid.optimizer.RecomputeOptimizer(sgd)
sgd._set_checkpoints([prediction])
sgd.minimize(cost)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feed_data = gen_data()
drop_vec = exe.run(feed=feed_data,
program=fluid.default_main_program(),
fetch_list=[
"dropout_with_seed_gpu.tmp_1",
"dropout_with_seed_gpu.tmp_1.subprog_0"
])
self.assertEqual(drop_vec[0].tolist(), drop_vec[1].tolist())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册