提交 9aabdccc 编写于 作者: L liuqi

Add cpu test util to simplify the unit test.

上级 196fd847
...@@ -5,20 +5,49 @@ package( ...@@ -5,20 +5,49 @@ package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
) )
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android") load("//mace:mace.bzl", "if_android")
cc_library(
name = "test",
testonly = 1,
srcs = [
"ops_test_util.h",
],
deps = [
"//mace/core",
"@gtest//:gtest",
],
)
cc_library( cc_library(
name = "ops", name = "ops",
srcs = glob(["*.cc"]), srcs = [
hdrs = glob(["*.h"]), "batch_norm.cc",
"relu.cc",
],
hdrs = glob([
"relu.h",
"batch_norm.h",
]),
copts = ["-std=c++11"],
deps = [ deps = [
"//mace/core",
"//mace/kernels",
"//mace/proto:cc_proto", "//mace/proto:cc_proto",
"//mace/core:core",
"//mace/kernels:kernels",
], ],
copts = ['-std=c++11'],
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "batch_norm_test",
srcs = ["batch_norm_test.cc"],
copts = ["-std=c++11"],
linkstatic = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
...@@ -25,11 +25,11 @@ class BatchNormOp : public Operator<D, T> { ...@@ -25,11 +25,11 @@ class BatchNormOp : public Operator<D, T> {
const float variance_epsilon = this->template GetSingleArgument<float>("variance_epsilon", 1e-4); const float variance_epsilon = this->template GetSingleArgument<float>("variance_epsilon", 1e-4);
REQUIRE(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size()); MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ", input->dim_size());
REQUIRE(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size()); MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ", scale->dim_size());
REQUIRE(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size()); MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ", offset->dim_size());
REQUIRE(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size()); MACE_CHECK(mean->dim_size() == 1, "mean must be 1-dimensional. ", mean->dim_size());
REQUIRE(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size()); MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ", var->dim_size());
Tensor* output = this->Output(0); Tensor* output = this->Output(0);
output->ResizeLike(input); output->ResizeLike(input);
......
...@@ -27,7 +27,7 @@ class OpDefBuilder { ...@@ -27,7 +27,7 @@ class OpDefBuilder {
return *this; return *this;
} }
void Finalize(OperatorDef* op_def) const { void Finalize(OperatorDef* op_def) const {
REQUIRE(op_def != NULL, "input should not be null."); MACE_CHECK(op_def != NULL, "input should not be null.");
*op_def = op_def_; *op_def = op_def_;
} }
OperatorDef op_def_; OperatorDef op_def_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册