提交 72ba0270 编写于 作者: D dangqingqing

Add bool type for attribute and use it in dropout_op.

上级 7ee916b0
......@@ -21,7 +21,7 @@ namespace framework {
template <>
AttrType AttrTypeID<bool>() {
return BOOL;
return BOOLEAN;
}
template <>
AttrType AttrTypeID<int>() {
......@@ -37,7 +37,7 @@ AttrType AttrTypeID<std::string>() {
}
template <>
AttrType AttrTypeID<std::vector<bool>>() {
return BOOLS;
return BOOLEANS;
}
template <>
AttrType AttrTypeID<std::vector<int>>() {
......@@ -58,7 +58,7 @@ AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case paddle::framework::AttrType::BOOL: {
case paddle::framework::AttrType::BOOLEAN: {
return attr_desc.b();
}
case paddle::framework::AttrType::INT: {
......@@ -70,7 +70,7 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
case paddle::framework::AttrType::STRING: {
return attr_desc.s();
}
case paddle::framework::AttrType::BOOLS: {
case paddle::framework::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
......
......@@ -23,8 +23,8 @@ enum AttrType {
FLOATS = 4;
STRINGS = 5;
INT_PAIRS = 6;
BOOL = 7;
BOOLS = 8;
BOOLEAN = 7;
BOOLEANS = 8;
}
message IntPair {
......@@ -47,7 +47,7 @@ message OpDesc {
repeated string strings = 8;
repeated IntPair int_pairs = 9;
optional bool b = 10;
repeated bool bools = 6;
repeated bool bools = 11;
};
message Var {
......
......@@ -29,13 +29,10 @@ class DropoutOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<LoDTensor>("Out")->Resize(dims);
if (ctx.Attr<int>("is_training") == 1) {
if (ctx.Attr<bool>("is_training")) {
ctx.Output<LoDTensor>("Mask")->Resize(dims);
}
}
......@@ -49,8 +46,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f);
// TODO(xinghai-sun): use bool for is_training after bool is supported.
AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op.");
......@@ -59,7 +55,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Dropout Operator.
"Dropout" refers to randomly dropping out units in a nerual network. It is a
'Dropout' refers to randomly dropping out units in a nerual network. It is a
regularization technique for reducing overfitting by preventing neuron
co-adaption during training. The dropout operator randomly set (according to
the given dropout probability) the outputs of some units to zero, while others
......@@ -75,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Attr<int>("is_training"), 1,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(ctx.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
......@@ -85,9 +81,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
......@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel {
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>();
if (context.Attr<int>("is_training") == 1) {
if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims());
......
......@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel {
auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
if (context.Attr<int>("is_training") == 1) {
if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed");
......@@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1,
"GradOp is only callable when is_training is true");
PADDLE_ENFORCE(context.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -89,7 +89,7 @@ class OpDescCreationMethod(object):
new_attr.f = user_defined_attr
elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOL:
elif attr.type == framework_pb2.BOOLEAN:
new_attr.b = user_defined_attr
elif attr.type == framework_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
......@@ -97,7 +97,7 @@ class OpDescCreationMethod(object):
new_attr.floats.extend(user_defined_attr)
elif attr.type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLS:
elif attr.type == framework_pb2.BOOLEANS:
new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr:
......
......@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
def test_check_output(self):
......@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'is_training': 1}
self.attrs = {'dropout_prob': 1.0, 'is_training': True}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
......@@ -29,7 +29,7 @@ class TestDropoutOp3(TestDropoutOp):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
......@@ -37,7 +37,7 @@ class TestDropoutOp4(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_training': 0}
self.attrs = {'dropout_prob': 0.35, 'is_training': False}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self):
......@@ -48,7 +48,7 @@ class TestDropoutOp5(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_training': 0}
self.attrs = {'dropout_prob': 0.75, 'is_training': False}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册