未验证 提交 f8e0ba45 编写于 作者: H huzhiqiang 提交者: GitHub

【BUG FIX】FIX Batch norm op (#2846)

上级 b38753da
...@@ -82,7 +82,20 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -82,7 +82,20 @@ bool BatchNormOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.variance = param_.variance =
scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>(); scope->FindVar(op_desc.Input("Variance").front())->GetMutable<Tensor>();
param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>(); param_.y = scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
auto is_test_type = op_desc.GetAttrType("is_test");
switch (is_test_type) {
case OpDescAPI::AttrType::INT:
param_.is_test = op_desc.GetAttr<int>("is_test"); param_.is_test = op_desc.GetAttr<int>("is_test");
break;
case OpDescAPI::AttrType::BOOLEAN:
param_.is_test = op_desc.GetAttr<bool>("is_test");
break;
default:
LOG(FATAL) << "Unsupported attribute type: the type of attribute "
"`is_test` in BatchNormOP should be int or bool.";
}
if (op_desc.HasAttr("use_global_stats")) { if (op_desc.HasAttr("use_global_stats")) {
param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats"); param_.use_global_stats = op_desc.GetAttr<bool>("use_global_stats");
} }
......
...@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) { ...@@ -46,7 +46,7 @@ TEST(batch_norm_op_lite, test) {
desc.SetInput("Mean", {"mean"}); desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"}); desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"}); desc.SetOutput("Y", {"y"});
desc.SetAttr("is_test", static_cast<int>(1)); desc.SetAttr("is_test", static_cast<bool>(true));
desc.SetAttr("use_global_stats", false); desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f); desc.SetAttr("momentum", 0.9f);
...@@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) { ...@@ -101,7 +101,7 @@ TEST(batch_norm_op_lite, test_enable_is_test) {
desc.SetOutput("VarianceOut", {"variance_out"}); desc.SetOutput("VarianceOut", {"variance_out"});
desc.SetOutput("SavedMean", {"saved_mean"}); desc.SetOutput("SavedMean", {"saved_mean"});
desc.SetOutput("SavedVariance", {"saved_variance"}); desc.SetOutput("SavedVariance", {"saved_variance"});
desc.SetAttr("is_test", static_cast<int>(0)); desc.SetAttr("is_test", static_cast<bool>(false));
desc.SetAttr("use_global_stats", false); desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f); desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f); desc.SetAttr("momentum", 0.9f);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册