提交 d35417e7 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #4216 from qingqing01/attr_bool

Add bool type for attribute and use it in dropout_op.
...@@ -28,6 +28,10 @@ ProgramDesc& GetProgramDesc() { ...@@ -28,6 +28,10 @@ ProgramDesc& GetProgramDesc() {
return *g_program_desc; return *g_program_desc;
} }
template <>
AttrType AttrTypeID<bool>() {
return BOOLEAN;
}
template <> template <>
AttrType AttrTypeID<int>() { AttrType AttrTypeID<int>() {
return INT; return INT;
...@@ -41,6 +45,10 @@ AttrType AttrTypeID<std::string>() { ...@@ -41,6 +45,10 @@ AttrType AttrTypeID<std::string>() {
return STRING; return STRING;
} }
template <> template <>
AttrType AttrTypeID<std::vector<bool>>() {
return BOOLEANS;
}
template <>
AttrType AttrTypeID<std::vector<int>>() { AttrType AttrTypeID<std::vector<int>>() {
return INTS; return INTS;
} }
...@@ -63,6 +71,9 @@ AttrType AttrTypeID<BlockDesc>() { ...@@ -63,6 +71,9 @@ AttrType AttrTypeID<BlockDesc>() {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: {
return attr_desc.b();
}
case framework::AttrType::INT: { case framework::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
} }
...@@ -72,6 +83,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -72,6 +83,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
case framework::AttrType::STRING: { case framework::AttrType::STRING: {
return attr_desc.s(); return attr_desc.s();
} }
case framework::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
return val;
}
case framework::AttrType::INTS: { case framework::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size()); std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) { for (int i = 0; i < attr_desc.ints_size(); ++i) {
......
...@@ -27,8 +27,9 @@ limitations under the License. */ ...@@ -27,8 +27,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>, typedef boost::variant<boost::blank, bool, int, float, std::string,
std::vector<float>, std::vector<std::string>, std::vector<bool>, std::vector<int>, std::vector<float>,
std::vector<std::string>,
std::vector<std::pair<int, int>>, BlockDesc*> std::vector<std::pair<int, int>>, BlockDesc*>
Attribute; Attribute;
......
...@@ -23,7 +23,9 @@ enum AttrType { ...@@ -23,7 +23,9 @@ enum AttrType {
FLOATS = 4; FLOATS = 4;
STRINGS = 5; STRINGS = 5;
INT_PAIRS = 6; INT_PAIRS = 6;
BLOCK = 7; BOOLEAN = 7;
BOOLEANS = 8;
BLOCK = 9;
} }
message IntPair { message IntPair {
...@@ -45,7 +47,9 @@ message OpDesc { ...@@ -45,7 +47,9 @@ message OpDesc {
repeated float floats = 7; repeated float floats = 7;
repeated string strings = 8; repeated string strings = 8;
repeated IntPair int_pairs = 9; repeated IntPair int_pairs = 9;
optional int32 block_idx = 10; optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
}; };
message Var { message Var {
......
...@@ -33,19 +33,16 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -33,19 +33,16 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2, PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2."); "Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must " "The 1st dimension of Input(X) and Input(Label) must "
"be equal."); "be equal.");
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of " "If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal."); "Input(X) and Input(Label) must be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of " "If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1."); "Input(Label) must be 1.");
} }
...@@ -73,9 +70,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -73,9 +70,6 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2."); PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 2, PADDLE_ENFORCE_EQ(label->dims().size(), 2,
"Input(Label)'s rank must be 2."); "Input(Label)'s rank must be 2.");
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("soft_label") == 0 ||
ctx.Attr<int>("soft_label") == 1);
PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0], PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
"The 1st dimension of Input(X) and Input(Label) must " "The 1st dimension of Input(X) and Input(Label) must "
"be equal."); "be equal.");
...@@ -84,13 +78,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -84,13 +78,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"be equal."); "be equal.");
PADDLE_ENFORCE_EQ(dy->dims()[1], 1, PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
"The 2nd dimension of Input(Y@Grad) must be 1."); "The 2nd dimension of Input(Y@Grad) must be 1.");
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1], PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
"If Attr(soft_label) == 1, The 2nd dimension of " "If Attr(soft_label) == true, The 2nd dimension of "
"Input(X) and Input(Label) must be equal."); "Input(X) and Input(Label) must be equal.");
} else { } else {
PADDLE_ENFORCE_EQ(label->dims()[1], 1, PADDLE_ENFORCE_EQ(label->dims()[1], 1,
"If Attr(soft_label) == 0, The 2nd dimension of " "If Attr(soft_label) == false, The 2nd dimension of "
"Input(Label) must be 1."); "Input(Label) must be 1.");
} }
...@@ -107,7 +101,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -107,7 +101,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of CrossEntropyOp"); AddInput("X", "The first input of CrossEntropyOp");
AddInput("Label", "The second input of CrossEntropyOp"); AddInput("Label", "The second input of CrossEntropyOp");
AddOutput("Y", "The output of CrossEntropyOp"); AddOutput("Y", "The output of CrossEntropyOp");
AddAttr<int>("soft_label", "Is soft label. Default zero.").SetDefault(0); AddAttr<bool>("soft_label", "Is soft label. Default zero.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
CrossEntropy Operator. CrossEntropy Operator.
...@@ -115,12 +110,12 @@ CrossEntropy Operator. ...@@ -115,12 +110,12 @@ CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss It supports both standard cross-entropy and soft-label cross-entropy loss
computation. computation.
1) One-hot cross-entropy: 1) One-hot cross-entropy:
soft_label = 0, Label[i, 0] indicates the class index for sample i: soft_label = False, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]]) Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy: 2) Soft-label cross-entropy:
soft_label = 1, Label[i, j] indicates the soft label of class j soft_label = True, Label[i, j] indicates the soft label of class j
for sample i: for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
......
...@@ -102,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -102,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
int grid = (n + block - 1) / block; int grid = (n + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream // TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext. // base on ExecutionContext.
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
d); d);
...@@ -137,7 +137,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { ...@@ -137,7 +137,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
grid = (n + block - 1) / block; grid = (n + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream // TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext. // base on ExecutionContext.
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = label->data<T>(); auto* label_data = label->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block>>>( SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
dx_data, dy_data, x_data, label_data, n, d); dx_data, dy_data, x_data, label_data, n, d);
......
...@@ -51,7 +51,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -51,7 +51,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
int batch_size = x->dims()[0]; int batch_size = x->dims()[0];
int class_num = x->dims()[1]; int class_num = x->dims()[1];
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0; int index = 0;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
...@@ -92,7 +92,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -92,7 +92,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
int class_num = x->dims()[1]; int class_num = x->dims()[1];
// TODO(qingqing): make zero setting an common function. // TODO(qingqing): make zero setting an common function.
if (ctx.Attr<int>("soft_label") == 1) { if (ctx.Attr<bool>("soft_label")) {
auto* label_data = ctx.Input<Tensor>("Label")->data<T>(); auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
int index = 0; int index = 0;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
......
...@@ -28,13 +28,10 @@ class DropoutOp : public framework::OperatorWithKernel { ...@@ -28,13 +28,10 @@ class DropoutOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto dims = ctx.Input<Tensor>("X")->dims(); auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<Tensor>("Out")->Resize(dims); ctx.Output<Tensor>("Out")->Resize(dims);
if (ctx.Attr<int>("is_training") == 1) { if (ctx.Attr<bool>("is_training")) {
ctx.Output<Tensor>("Mask")->Resize(dims); ctx.Output<Tensor>("Mask")->Resize(dims);
} }
ctx.ShareLoD("X", /*->*/ "Out"); ctx.ShareLoD("X", /*->*/ "Out");
...@@ -49,8 +46,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,8 +46,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.") AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f); .SetDefault(.5f);
// TODO(xinghai-sun): use bool for is_training after bool is supported. AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
AddInput("X", "The input of dropout op."); AddInput("X", "The input of dropout op.");
AddOutput("Out", "The output of dropout op."); AddOutput("Out", "The output of dropout op.");
...@@ -59,7 +55,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -59,7 +55,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Dropout Operator. Dropout Operator.
"Dropout" refers to randomly dropping out units in a nerual network. It is a 'Dropout' refers to randomly dropping out units in a nerual network. It is a
regularization technique for reducing overfitting by preventing neuron regularization technique for reducing overfitting by preventing neuron
co-adaption during training. The dropout operator randomly set (according to co-adaption during training. The dropout operator randomly set (according to
the given dropout probability) the outputs of some units to zero, while others the given dropout probability) the outputs of some units to zero, while others
...@@ -75,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -75,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Attr<int>("is_training"), 1, PADDLE_ENFORCE(ctx.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true"); "GradOp is only callable when is_training is true");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
...@@ -85,9 +81,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel { ...@@ -85,9 +81,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0); PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1); PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
ctx.Attr<int>("is_training") == 1);
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE_EQ(x_dims, out_dims, PADDLE_ENFORCE_EQ(x_dims, out_dims,
......
...@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel { ...@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel {
auto Y = EigenMatrix<T>::Reshape(*y, 1); auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
if (context.Attr<int>("is_training") == 1) { if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int size = framework::product(mask->dims()); int size = framework::product(mask->dims());
......
...@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel { ...@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel {
auto* y_data = y->mutable_data<T>(context.GetPlace()); auto* y_data = y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob"); AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
if (context.Attr<int>("is_training") == 1) { if (context.Attr<bool>("is_training")) {
auto* mask = context.Output<Tensor>("Mask"); auto* mask = context.Output<Tensor>("Mask");
auto* mask_data = mask->mutable_data<T>(context.GetPlace()); auto* mask_data = mask->mutable_data<T>(context.GetPlace());
int seed = context.Attr<int>("seed"); int seed = context.Attr<int>("seed");
...@@ -65,8 +65,8 @@ template <typename Place, typename T> ...@@ -65,8 +65,8 @@ template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel { class DropoutGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1, PADDLE_ENFORCE(context.Attr<bool>("is_training"),
"GradOp is only callable when is_training is true"); "GradOp is only callable when is_training is true");
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X")); auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out")); auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -89,12 +89,16 @@ class OpDescCreationMethod(object): ...@@ -89,12 +89,16 @@ class OpDescCreationMethod(object):
new_attr.f = user_defined_attr new_attr.f = user_defined_attr
elif attr.type == framework_pb2.STRING: elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOLEAN:
new_attr.b = user_defined_attr
elif attr.type == framework_pb2.INTS: elif attr.type == framework_pb2.INTS:
new_attr.ints.extend(user_defined_attr) new_attr.ints.extend(user_defined_attr)
elif attr.type == framework_pb2.FLOATS: elif attr.type == framework_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr) new_attr.floats.extend(user_defined_attr)
elif attr.type == framework_pb2.STRINGS: elif attr.type == framework_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLEANS:
new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS: elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr: for p in user_defined_attr:
pair = new_attr.int_pairs.add() pair = new_attr.int_pairs.add()
......
...@@ -24,15 +24,15 @@ class TestCosSimOp(OpTest): ...@@ -24,15 +24,15 @@ class TestCosSimOp(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.06)
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
self.check_grad( self.check_grad(
['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X")) ['Y'], 'Out', max_relative_error=0.06, no_grad_set=set("X"))
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
self.check_grad( self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.06, no_grad_set=set('Y'))
class TestCosSimOp2(TestCosSimOp): class TestCosSimOp2(TestCosSimOp):
......
...@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest): ...@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
dtype="float32") dtype="float32")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {'soft_label': 0} self.attrs = {'soft_label': False}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -45,7 +45,7 @@ class TestCrossEntropyOp2(OpTest): ...@@ -45,7 +45,7 @@ class TestCrossEntropyOp2(OpTest):
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label} self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy} self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1} self.attrs = {'soft_label': True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -76,7 +76,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -76,7 +76,7 @@ class TestCrossEntropyOp3(OpTest):
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {'X': X, 'Label': label} self.inputs = {'X': X, 'Label': label}
self.outputs = {'Y': cross_entropy} self.outputs = {'Y': cross_entropy}
self.attrs = {'soft_label': 1} self.attrs = {'soft_label': True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
...@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest): ...@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1} self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
def test_check_output(self): def test_check_output(self):
...@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp): ...@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 1.0, 'is_training': 1} self.attrs = {'dropout_prob': 1.0, 'is_training': True}
self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
...@@ -29,7 +29,7 @@ class TestDropoutOp3(TestDropoutOp): ...@@ -29,7 +29,7 @@ class TestDropoutOp3(TestDropoutOp):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'is_training': 1} self.attrs = {'dropout_prob': 0.0, 'is_training': True}
self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))} self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
...@@ -37,7 +37,7 @@ class TestDropoutOp4(OpTest): ...@@ -37,7 +37,7 @@ class TestDropoutOp4(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64)).astype("float32")} self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
self.attrs = {'dropout_prob': 0.35, 'is_training': 0} self.attrs = {'dropout_prob': 0.35, 'is_training': False}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self): def test_check_output(self):
...@@ -48,7 +48,7 @@ class TestDropoutOp5(OpTest): ...@@ -48,7 +48,7 @@ class TestDropoutOp5(OpTest):
def setUp(self): def setUp(self):
self.op_type = "dropout" self.op_type = "dropout"
self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
self.attrs = {'dropout_prob': 0.75, 'is_training': 0} self.attrs = {'dropout_prob': 0.75, 'is_training': False}
self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']}
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册