diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 1cd2f904988ce59a09ca6c947f0facbd2d820365..0394fa1d76502400918f3144a3c494374d577345 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -5,20 +5,49 @@ package( default_visibility = ["//visibility:public"], ) - licenses(["notice"]) # Apache 2.0 load("//mace:mace.bzl", "if_android") +cc_library( + name = "test", + testonly = 1, + srcs = [ + "ops_test_util.h", + ], + deps = [ + "//mace/core", + "@gtest//:gtest", + ], +) + cc_library( name = "ops", - srcs = glob(["*.cc"]), - hdrs = glob(["*.h"]), + srcs = [ + "batch_norm.cc", + "relu.cc", + ], + hdrs = glob([ + "relu.h", + "batch_norm.h", + ]), + copts = ["-std=c++11"], deps = [ + "//mace/core", + "//mace/kernels", "//mace/proto:cc_proto", - "//mace/core:core", - "//mace/kernels:kernels", ], - copts = ['-std=c++11'], alwayslink = 1, ) + +cc_test( + name = "batch_norm_test", + srcs = ["batch_norm_test.cc"], + copts = ["-std=c++11"], + linkstatic = 1, + deps = [ + ":ops", + ":test", + "@gtest//:gtest_main", + ], +) diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h index a2e175a78b73d31288385211fabb2a221536c6b8..2b4fad42049b629916aa5a78fac860ceb6f36560 100644 --- a/mace/ops/batch_norm.h +++ b/mace/ops/batch_norm.h @@ -25,11 +25,11 @@ class BatchNormOp : public Operator { const float variance_epsilon = this->template GetSingleArgument("variance_epsilon", 1e-4); - REQUIRE(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size()); - REQUIRE(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size()); - REQUIRE(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size()); - REQUIRE(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size()); - REQUIRE(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size()); + MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size()); + MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size()); + MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size()); + MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size()); + MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size()); Tensor* output = this->Output(0); output->ResizeLike(input); diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 82b7063da6e8c3c35bb6613b9c02d0f6d2ab6671..61085f7dd0ed090fb248db3f76037199d7538c78 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -27,7 +27,7 @@ class OpDefBuilder { return *this; } void Finalize(OperatorDef* op_def) const { - REQUIRE(op_def != NULL, "input should not be null."); + MACE_CHECK(op_def != NULL, "input should not be null."); *op_def = op_def_; } OperatorDef op_def_;