“4d6f4df70da785e2513c19d2e99fa0c5c2f429fd”上不存在“develop/doc/api/v1/trainer_config_helpers/evaluators.html”
提交 585d12a3 编写于 作者: X Xinghai Sun

Add is_training attr and testing phrase compuation to dropout operator.

Change type of dropout_prob to template typename.
上级 32645b52
...@@ -30,6 +30,10 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -30,6 +30,10 @@ class DropoutOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
// resize // resize
auto dims = ctx.Input<Tensor>("X")->dims(); auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<LoDTensor>("Out")->Resize(dims); ctx.Output<LoDTensor>("Out")->Resize(dims);
...@@ -37,13 +41,16 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -37,13 +41,16 @@ class DropoutOp : public framework::OperatorWithKernel {
} }
}; };
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
DropoutOpMaker(framework::OpProto *proto, DropoutOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<float>("dropout_prob", "Probability for dropping out units.") AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f); .SetDefault(.5f);
// TODO(xinghai-sun): use bool for is_training after bool is supported.
AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddInput("X", "The input of dropout op."); AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op."); AddOutput("Out", "The output of dropout op.");
...@@ -61,6 +68,7 @@ being set to their inputs. ...@@ -61,6 +68,7 @@ being set to their inputs.
} }
}; };
template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel { class DropoutOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -72,8 +80,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -72,8 +80,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null."); "Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto mask_dims = ctx.Input<Tensor>("Mask")->dims(); auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
...@@ -91,9 +102,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -91,9 +102,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad); ops::DropoutOpGrad<float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float>); dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>); dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
...@@ -22,18 +22,18 @@ ...@@ -22,18 +22,18 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T, typename AttrType>
struct MaskGenerator { struct MaskGenerator {
float dropout_prob; AttrType dropout_prob;
int seed; int seed;
__host__ __device__ MaskGenerator(float dropout_prob, int seed) __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
: dropout_prob(dropout_prob), seed(seed) {} : dropout_prob(dropout_prob), seed(seed) {}
__host__ __device__ T operator()(const unsigned int n) const { __host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed); rng.seed(seed);
thrust::uniform_real_distribution<T> dist(0, 1); thrust::uniform_real_distribution<AttrType> dist(0, 1);
rng.discard(n); rng.discard(n);
if (dist(rng) < dropout_prob) { if (dist(rng) < dropout_prob) {
return static_cast<T>(0); return static_cast<T>(0);
...@@ -46,33 +46,35 @@ struct MaskGenerator { ...@@ -46,33 +46,35 @@ struct MaskGenerator {
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename Place, typename T> template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel { class GPUDropoutKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out"); auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask");
y->mutable_data<T>(context.GetPlace()); y->mutable_data<T>(context.GetPlace());
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
float dropout_prob = context.Attr<float>("dropout_prob"); auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto M = EigenMatrix<T>::Reshape(*mask, 1);
auto place = context.GetEigenDevice<Place>();
int size = framework::product(mask->dims());
if (context.Attr<int>("is_training") == 1) {
int seed = context.Attr<int>("seed"); int seed = context.Attr<int>("seed");
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int size = framework::product(mask->dims());
T* mask_data = mask->mutable_data<T>(context.GetPlace());
thrust::transform(index_sequence_begin, index_sequence_begin + size, thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(mask_data), thrust::device_ptr<T>(mask_data),
MaskGenerator<T>(dropout_prob, seed)); MaskGenerator<T, AttrType>(dropout_prob, seed));
auto dims = x->dims();
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto X = EigenMatrix<T>::From(*x, new_dims);
auto Y = EigenMatrix<T>::From(*y, new_dims);
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = X * M; Y.device(place) = X * M;
// TODO(xinghai-sun): add test time logits. } else {
cudaMemset(mask_data, 0, sizeof(T) * size);
Y.device(place) = X * dropout_prob;
}
} }
}; };
...@@ -81,6 +83,6 @@ class GPUDropoutKernel : public framework::OpKernel { ...@@ -81,6 +83,6 @@ class GPUDropoutKernel : public framework::OpKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>); dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>); dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
...@@ -25,23 +25,24 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,23 +25,24 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel { class CPUDropoutKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out"); auto* y = context.Output<Tensor>("Out");
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
T* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
T* y_data = y->mutable_data<T>(context.GetPlace()); auto* y_data = y->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>(); const auto* x_data = x->data<T>();
float dropout_prob = context.Attr<float>("dropout_prob"); AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
int seed = context.Attr<int>("seed");
if (context.Attr<int>("is_training") == 1) {
int seed = context.Attr<int>("seed");
std::minstd_rand engine; std::minstd_rand engine;
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist(0, 1); std::uniform_real_distribution<AttrType> dist(0, 1);
size_t size = framework::product(mask->dims()); size_t size = framework::product(mask->dims());
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
if (dist(engine) < dropout_prob) { if (dist(engine) < dropout_prob) {
...@@ -52,7 +53,14 @@ class CPUDropoutKernel : public framework::OpKernel { ...@@ -52,7 +53,14 @@ class CPUDropoutKernel : public framework::OpKernel {
y_data[i] = x_data[i]; y_data[i] = x_data[i];
} }
} }
// TODO: add test phase logits. } else {
size_t size = framework::product(mask->dims());
memset(mask_data, 0, sizeof(T) * size);
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = X * dropout_prob;
}
} }
}; };
...@@ -60,21 +68,19 @@ template <typename Place, typename T> ...@@ -60,21 +68,19 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel { class DropoutGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1,
"Only callable when is_training is true");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X")); auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out")); auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask"); auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace()); grad_x->mutable_data<T>(context.GetPlace());
auto dims = grad_x->dims(); auto M = EigenMatrix<T>::Reshape(*mask, 1);
int size = static_cast<int>(framework::product(dims)); auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
auto new_dims = framework::make_ddim({dims[0], size / dims[0]}); auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
auto M = EigenMatrix<T>::From(*mask, new_dims);
auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
dX.device(place) = dY * M; dX.device(place) = dY * M;
// TODO: add test time logits.
} }
}; };
......
...@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest): ...@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0} self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
def test_check_output(self): def test_check_output(self):
...@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp): ...@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0} self.attrs = {'dropout_prob': 1.0, 'is_training': 1}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
...@@ -29,9 +29,37 @@ class TestDropoutOp3(TestDropoutOp): ...@@ -29,9 +29,37 @@ class TestDropoutOp3(TestDropoutOp):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0} self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))} self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
class TestDropoutOp4(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_training': 0}
self.outputs = {
'Out': self.inputs['X'] * self.attrs['dropout_prob'],
'Mask': np.zeros((32, 64))
}
def test_check_output(self):
self.check_output()
class TestDropoutOp5(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_training': 0}
self.outputs = {
'Out': self.inputs['X'] * self.attrs['dropout_prob'],
'Mask': np.zeros((32, 64, 3))
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册