提交 d03dbb97 编写于 作者: K Kexin Zhao

remove AttrType

上级 05ad1583
......@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
};
template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>);
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -23,13 +23,13 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T, typename AttrType>
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
const AttrType dropout_prob, const T* src,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<AttrType> dist(0, 1);
thrust::uniform_real_distribution<float> dist(0, 1);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
......@@ -45,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T, typename AttrType>
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob"));
float dropout_prob = context.Attr<float>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
......@@ -71,8 +71,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<T, AttrType><<<grid, threads, 0,
context.cuda_device_context().stream()>>>(
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
......@@ -86,7 +86,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16, float>);
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T, typename AttrType>
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册