提交 b1a18552 编写于 作者: X Xinghai Sun

Fixed SEGFAULT of dropout operator in GPU.

上级 9a44f3d6
......@@ -37,6 +37,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
DropoutOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("dropout_prob", "Dropout probability.").SetDefault(.5f);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The dropout mask.").AsIntermediate();
......@@ -75,7 +77,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(dropout,
ops::DropoutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
......@@ -16,7 +16,7 @@
#include "paddle/operators/dropout_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(dropout,
ops::DropoutKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,6 +13,11 @@
limitations under the License. */
#pragma once
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
......@@ -25,25 +30,85 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class DropoutKernel : public framework::OpKernel {
class CPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask");
T* mask_data = mask->mutable_data<T>(context.GetPlace());
T* y_data = y->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
int seed = context.op_.GetAttr<int>("seed");
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<T> dist(0, 1);
size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
mask_data[i] = 1;
y_data[i] = (1 - dropout_prob) * x_data[i];
}
}
}
};
template <typename T>
struct MaskGenerator {
float dropout_prob_;
int seed_;
__host__ __device__ MaskGenerator(float dropout_prob, int seed)
: dropout_prob_(dropout_prob), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0, 1);
rng.discard(n);
if (dist(rng) < dropout_prob_) {
return static_cast<T>(0);
} else {
return static_cast<T>(1);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask");
mask->mutable_data<T>(context.GetPlace());
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
int seed = context.op_.GetAttr<int>("seed");
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int size = framework::product(mask->dims());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(mask_data),
MaskGenerator<T>(dropout_prob, seed));
auto dims = x->dims();
auto X = EigenMatrix<T>::From(*x);
auto Y = EigenMatrix<T>::From(*y);
auto M = EigenMatrix<T>::From(*mask);
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto place = context.GetEigenDevice<Place>();
M.device(place).setRandom<UniformRandomGenerator>();
float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
M.device(place) = (M > dropout_prob).cast<float>();
Y.device(place) = X * Y;
Y.device(place) = X * M * (1 - dropout_prob);
}
};
......@@ -57,12 +122,15 @@ class DropoutGradKernel : public framework::OpKernel {
grad_x->mutable_data<T>(context.GetPlace());
auto dims = grad_x->dims();
auto M = EigenMatrix<T>::From(*mask);
auto dX = EigenMatrix<T>::From(*grad_x);
auto dY = EigenMatrix<T>::From(*grad_y);
int size = static_cast<int>(framework::product(dims));
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
auto place = context.GetEigenDevice<Place>();
dX.device(place) = dY * M;
float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
dX.device(place) = dY * M * (1 - dropout_prob);
}
};
......
......@@ -4,6 +4,7 @@ py_test(test_scope SRCS test_scope.py)
py_test(test_tensor SRCS test_tensor.py)
py_test(test_mul_op SRCS test_mul_op.py)
py_test(test_dropout_op SRCS test_dropout_op.py)
py_test(test_mean_op SRCS test_mean_op.py)
......
......@@ -6,13 +6,13 @@ from paddle.v2.framework.op import Operator
class OpTestMeta(type):
"""
Operator Test ClassMeta.
It injects `test_all` method into user's OperatorTest class, to make Python
It injects `test_all` method into user's OperatorTest class, to make Python
unittest module run that method.
The `test_all` read what value is stored in `self`. It use self's values to
create and run a operator, and check whether that op is OK or not.
See `test_add_two_op` for example usage.
"""
......
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestDropoutOpProbZero(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
class TestDropoutOpAllProbOne(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
class DropoutGradOpTest(GradientChecker):
def test_dropout_2d(self):
op = create_op("dropout")
inputs = {'X': np.random.random((10, 5)).astype("float32")}
self.compare_grad(op, inputs)
self.check_grad(op, inputs, set(["X"]), "Out")
def test_dropout_3d(self):
op = create_op("dropout")
inputs = {'X': np.random.random((10, 5, 4)).astype("float32")}
self.compare_grad(op, inputs)
self.check_grad(op, inputs, set(["X"]), "Out")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册