From f8e0ba453f63fc805755d6cc16a7778a28243608 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Thu, 13 Feb 2020 09:50:59 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90BUG=20FIX=E3=80=91FIX=20Batch=20norm?= =?UTF-8?q?=20op=20=20(#2846)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lite/operators/batch_norm_op.cc | 15 ++++++++++++++- lite/operators/batch_norm_op_test.cc | 4 ++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/lite/operators/batch_norm_op.cc b/lite/operators/batch_norm_op.cc index 6faa9eb225..76c257c6d3 100644 --- a/lite/operators/batch_norm_op.cc +++ b/lite/operators/batch_norm_op.cc @@ -82,7 +82,20 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { param_.variance = scope->FindVar(op_desc.Input("Variance").front())->GetMutable(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable(); - param_.is_test = op_desc.GetAttr("is_test"); + + auto is_test_type = op_desc.GetAttrType("is_test"); + switch (is_test_type) { + case OpDescAPI::AttrType::INT: + param_.is_test = op_desc.GetAttr("is_test"); + break; + case OpDescAPI::AttrType::BOOLEAN: + param_.is_test = op_desc.GetAttr("is_test"); + break; + default: + LOG(FATAL) << "Unsupported attribute type: the type of attribute " + "`is_test` in BatchNormOP should be int or bool."; + } + if (op_desc.HasAttr("use_global_stats")) { param_.use_global_stats = op_desc.GetAttr("use_global_stats"); } diff --git a/lite/operators/batch_norm_op_test.cc b/lite/operators/batch_norm_op_test.cc index 574bb4cfd3..b79037c0bc 100644 --- a/lite/operators/batch_norm_op_test.cc +++ b/lite/operators/batch_norm_op_test.cc @@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { desc.SetInput("Mean", {"mean"}); desc.SetInput("Variance", {"variance"}); desc.SetOutput("Y", {"y"}); - desc.SetAttr("is_test", static_cast(1)); + desc.SetAttr("is_test", static_cast(true)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f); @@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) { desc.SetOutput("VarianceOut", {"variance_out"}); desc.SetOutput("SavedMean", {"saved_mean"}); desc.SetOutput("SavedVariance", {"saved_variance"}); - desc.SetAttr("is_test", static_cast(0)); + desc.SetAttr("is_test", static_cast(false)); desc.SetAttr("use_global_stats", false); desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("momentum", 0.9f); -- GitLab