提交 32645b52 编写于 作者: X Xinghai Sun

Move dropout gpu kernel to dropout_op.cu.

上级 05326629
...@@ -13,8 +13,72 @@ ...@@ -13,8 +13,72 @@
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/dropout_op.h" #include "paddle/operators/dropout_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct MaskGenerator {
float dropout_prob;
int seed;
__host__ __device__ MaskGenerator(float dropout_prob, int seed)
: dropout_prob(dropout_prob), seed(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<T> dist(0, 1);
rng.discard(n);
if (dist(rng) < dropout_prob) {
return static_cast<T>(0);
} else {
return static_cast<T>(1);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask");
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
int seed = context.Attr<int>("seed");
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int size = framework::product(mask->dims());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(mask_data),
MaskGenerator<T>(dropout_prob, seed));
auto dims = x->dims();
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = X * M;
// TODO(xinghai-sun): add test time logits.
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>); dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
......
...@@ -13,10 +13,6 @@ ...@@ -13,10 +13,6 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random> #include <random>
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
...@@ -60,60 +56,6 @@ class CPUDropoutKernel : public framework::OpKernel { ...@@ -60,60 +56,6 @@ class CPUDropoutKernel : public framework::OpKernel {
} }
}; };
template <typename T>
struct MaskGenerator {
float dropout_prob;
int seed;
__host__ __device__ MaskGenerator(float dropout_prob, int seed)
: dropout_prob(dropout_prob), seed(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<T> dist(0, 1);
rng.discard(n);
if (dist(rng) < dropout_prob) {
return static_cast<T>(0);
} else {
return static_cast<T>(1);
}
}
};
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask");
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
int seed = context.Attr<int>("seed");
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int size = framework::product(mask->dims());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(mask_data),
MaskGenerator<T>(dropout_prob, seed));
auto dims = x->dims();
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = X * M;
// TODO: add test time logits.
}
};
template <typename Place, typename T> template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel { class DropoutGradKernel : public framework::OpKernel {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册