Commit c7f91a94 authored by Xinghai Sun, committed by GitHub

Merge pull request #3817 from xinghai-sun/dropout

Add dropout operator.
// paddle/operators/dropout_op.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dropout_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class DropoutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after switching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<LoDTensor>("Out")->Resize(dims);
if (ctx.Attr<int>("is_training") == 1) {
ctx.Output<LoDTensor>("Mask")->Resize(dims);
}
}
};
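// Out always matches the shape of X; Mask is allocated only in the training
// phase, since inference needs no random mask.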
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
// TODO(xinghai-sun): use bool for is_training after bool is supported.
AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
AddComment(R"DOC(
Dropout Operator.
"Dropout" refers to randomly dropping out units in a nerual network. It is a
regularization technique for reducing overfitting by preventing neuron
co-adaption during training. The dropout operator randomly set (according to
the given dropout probability) the outputs of some units to zero, while others
being set to their inputs.
)DOC");
}
};
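// Illustrative example (values are hypothetical): with dropout_prob = 0.5 and
// X = [1.0, 2.0, 3.0, 4.0], one possible training-phase draw is
// Mask = [1, 0, 1, 0], giving Out = [1.0, 0.0, 3.0, 0.0]. In the inference
// phase (is_training = 0) this operator instead computes the deterministic
// scaling Out = X * dropout_prob.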
template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Attr<int>("is_training"), 1,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after switching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE_EQ(x_dims, out_dims,
"Dimensions of Input(X) and Out@Grad must be the same.");
auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
PADDLE_ENFORCE_EQ(x_dims, mask_dims,
"Dimensions of Input(X) and Mask must be the same.");
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
}
};
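// The backward pass reuses the mask saved by the forward pass:
// dX = dY * Mask elementwise, so units dropped in the forward pass receive
// zero gradient.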
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>);
REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
// paddle/operators/dropout_op.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/operators/dropout_op.h"
namespace paddle {
namespace operators {
template <typename T, typename AttrType>
struct MaskGenerator {
AttrType dropout_prob;
int seed;
__host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
: dropout_prob(dropout_prob), seed(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<AttrType> dist(0, 1);
rng.discard(n);
if (dist(rng) < dropout_prob) {
return static_cast<T>(0);
} else {
return static_cast<T>(1);
}
}
};
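// MaskGenerator reads the n-th value of a single seeded minstd_rand stream
// by re-seeding the engine and discarding the first n draws for element n.
// This keeps GPU threads free of shared RNG state and makes the mask
// reproducible for a fixed seed.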
// It seems that Eigen::Tensor::setRandom on GPU will SEGFAULT.
// Use std::random (on CPU) and thrust::random (Thrust is a standard library
// shipped with CUDA) to implement uniform random sampling instead.
template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
if (context.Attr<int>("is_training") == 1) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims());
int seed = context.Attr<int>("seed");
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(mask_data),
MaskGenerator<T, AttrType>(dropout_prob, seed));
auto M = EigenMatrix<T>::Reshape(*mask, 1);
Y.device(place) = X * M;
} else {
Y.device(place) = X * dropout_prob;
}
}
};
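// thrust::transform above evaluates MaskGenerator once per element index on
// the device, writing the 0/1 mask directly into Mask's GPU buffer before
// the elementwise multiply Y = X * M.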
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>);
REGISTER_OP_GPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
// paddle/operators/dropout_op.h
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
const auto* x_data = x->data<T>();
auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
if (context.Attr<int>("is_training") == 1) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed");
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<AttrType> dist(0, 1);
size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
mask_data[i] = 1;
y_data[i] = x_data[i];
}
}
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = X * dropout_prob;
}
}
};
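// On CPU the mask and output are produced in one pass using std::minstd_rand
// seeded from the "seed" attribute, so a fixed seed yields a reproducible
// mask across runs.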
template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1,
"GradOp is only callable when is_training is true");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto M = EigenMatrix<T>::Reshape(*mask, 1);
auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
auto place = context.GetEigenDevice<Place>();
dX.device(place) = dY * M;
}
};
} // namespace operators
} // namespace paddle
# test_dropout_op.py
import unittest
import numpy as np
from op_test import OpTest
class TestDropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.05)
class TestDropoutOp2(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'is_training': 1}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
class TestDropoutOp3(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
class TestDropoutOp4(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_training': 0}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self):
self.check_output()
class TestDropoutOp5(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_training': 0}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
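For reference, below is a minimal NumPy sketch of the semantics these tests
assume; the function names are illustrative and not part of the operator's API.

import numpy as np

def ref_dropout_forward(x, dropout_prob, is_training=1, seed=0):
    """Hypothetical reference model of this commit's forward semantics."""
    if is_training:
        rng = np.random.RandomState(seed)
        # Zero each unit independently with probability dropout_prob.
        mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
        return x * mask, mask
    # The inference path in this commit scales the input by dropout_prob itself.
    return x * dropout_prob, None

def ref_dropout_backward(dout, mask):
    # Gradient flows only through units kept in the forward pass.
    return dout * mask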