diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index 6faa9eb225c76735460227b77387d0b0e8157525..76c257c6d34f0a82a920eaf49c1ef88efbd0daf4 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -82,7 +82,20 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.variance = scope->FindVar(op_desc.Input("Variance").front())->GetMutable(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable(); - param_.is_test = op_desc.GetAttr("is_test"); + + auto is_test_type = op_desc.GetAttrType("is_test"); + switch (is_test_type) { + case OpDescAPI::AttrType::INT: + param_.is_test = op_desc.GetAttr("is_test"); + break; + case OpDescAPI::AttrType::BOOLEAN: + param_.is_test = op_desc.GetAttr("is_test"); + break; + default: + LOG(FATAL) << "Unsupported attribute type: the type of attribute " + "`is_test` in BatchNormOP should be int or bool."; + } + if (op_desc.HasAttr("use_global_stats")) { param_.use_global_stats = op_desc.GetAttr("use_global_stats"); } diff --git a/lite/operators/batch_norm_op_test.cc b/lite/operators/batch_norm_op_test.cc index 574bb4cfd316b05bf08086d865f4eb7de7dd03a3..b79037c0bc9c3e9188eaf0e54b3f958960ab0893 100644 --- a/lite/operators/batch_norm_op_test.cc +++ b/lite/operators/batch_norm_op_test.cc @@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { desc.SetInput("Mean", {"mean"}); desc.SetInput("Variance", {"variance"}); desc.SetOutput("Y", {"y"}); - desc.SetAttr("is_test", static_cast(1)); + desc.SetAttr("is_test", static_cast(true)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f); @@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) { desc.SetOutput("VarianceOut", {"variance_out"}); desc.SetOutput("SavedMean", {"saved_mean"}); desc.SetOutput("SavedVariance", {"saved_variance"}); - desc.SetAttr("is_test", static_cast(0)); + desc.SetAttr("is_test", static_cast(false)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f);