提交 89c48e60 编写于 作者: H hong19860320

remove the batch_norm unit test which the attr 'is_test' is enabled

test=develop
上级 add63d23
......@@ -67,72 +67,72 @@ TEST(batch_norm_op_lite, test) {
}
}
TEST(batch_norm_op_lite, test_enable_is_test) {
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* scale = scope.Var("scale")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* mean = scope.Var("mean")->GetMutable<Tensor>();
auto* variance = scope.Var("variance")->GetMutable<Tensor>();
auto* y = scope.Var("y")->GetMutable<Tensor>();
auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
x->Resize({2, 32, 10, 20});
auto x_dims = x->dims();
const int64_t channel_size = x_dims[1]; // NCHW
scale->Resize({channel_size});
bias->Resize({channel_size});
mean->Resize({channel_size});
variance->Resize({channel_size});
// TEST(batch_norm_op_lite, test_enable_is_test) {
// // prepare variables
// Scope scope;
// auto* x = scope.Var("x")->GetMutable<Tensor>();
// auto* scale = scope.Var("scale")->GetMutable<Tensor>();
// auto* bias = scope.Var("bias")->GetMutable<Tensor>();
// auto* mean = scope.Var("mean")->GetMutable<Tensor>();
// auto* variance = scope.Var("variance")->GetMutable<Tensor>();
// auto* y = scope.Var("y")->GetMutable<Tensor>();
// auto* mean_out = scope.Var("mean_out")->GetMutable<Tensor>();
// auto* variance_out = scope.Var("variance_out")->GetMutable<Tensor>();
// auto* saved_mean = scope.Var("saved_mean")->GetMutable<Tensor>();
// auto* saved_variance = scope.Var("saved_variance")->GetMutable<Tensor>();
// x->Resize({2, 32, 10, 20});
// auto x_dims = x->dims();
// const int64_t channel_size = x_dims[1]; // NCHW
// scale->Resize({channel_size});
// bias->Resize({channel_size});
// mean->Resize({channel_size});
// variance->Resize({channel_size});
// prepare op desc
cpp::OpDesc desc;
desc.SetType("batch_norm");
desc.SetInput("X", {"x"});
desc.SetInput("Scale", {"scale"});
desc.SetInput("Bias", {"bias"});
desc.SetInput("Mean", {"mean"});
desc.SetInput("Variance", {"variance"});
desc.SetOutput("Y", {"y"});
desc.SetOutput("MeanOut", {"mean_out"});
desc.SetOutput("VarianceOut", {"variance_out"});
desc.SetOutput("SavedMean", {"saved_mean"});
desc.SetOutput("SavedVariance", {"saved_variance"});
desc.SetAttr("is_test", false);
desc.SetAttr("use_global_stats", false);
desc.SetAttr("epsilon", 1e-5f);
desc.SetAttr("momentum", 0.9f);
desc.SetAttr("data_layout", std::string("NCHW"));
// // prepare op desc
// cpp::OpDesc desc;
// desc.SetType("batch_norm");
// desc.SetInput("X", {"x"});
// desc.SetInput("Scale", {"scale"});
// desc.SetInput("Bias", {"bias"});
// desc.SetInput("Mean", {"mean"});
// desc.SetInput("Variance", {"variance"});
// desc.SetOutput("Y", {"y"});
// desc.SetOutput("MeanOut", {"mean_out"});
// desc.SetOutput("VarianceOut", {"variance_out"});
// desc.SetOutput("SavedMean", {"saved_mean"});
// desc.SetOutput("SavedVariance", {"saved_variance"});
// desc.SetAttr("is_test", false);
// desc.SetAttr("use_global_stats", false);
// desc.SetAttr("epsilon", 1e-5f);
// desc.SetAttr("momentum", 0.9f);
// desc.SetAttr("data_layout", std::string("NCHW"));
BatchNormOp batch_norm("batch_norm");
// BatchNormOp batch_norm("batch_norm");
batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
batch_norm.Attach(desc, &scope);
batch_norm.CheckShape();
batch_norm.InferShape();
// batch_norm.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
// batch_norm.Attach(desc, &scope);
// batch_norm.CheckShape();
// batch_norm.InferShape();
// check output dims
auto y_dims = y->dims();
CHECK_EQ(y_dims.size(), x_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
CHECK_EQ(y_dims[i], x_dims[i]);
}
auto mean_out_dims = mean_out->dims();
auto variance_out_dims = variance_out->dims();
auto saved_mean_dims = saved_mean->dims();
auto saved_variance_dims = saved_variance->dims();
CHECK_EQ(mean_out_dims.size(), 1UL);
CHECK_EQ(variance_out_dims.size(), 1UL);
CHECK_EQ(saved_mean_dims.size(), 1UL);
CHECK_EQ(saved_variance_dims.size(), 1UL);
CHECK_EQ(mean_out_dims[0], channel_size);
CHECK_EQ(variance_out_dims[0], channel_size);
CHECK_EQ(saved_mean_dims[0], channel_size);
CHECK_EQ(saved_variance_dims[0], channel_size);
}
// // check output dims
// auto y_dims = y->dims();
// CHECK_EQ(y_dims.size(), x_dims.size());
// for (size_t i = 0; i < y_dims.size(); i++) {
// CHECK_EQ(y_dims[i], x_dims[i]);
// }
// auto mean_out_dims = mean_out->dims();
// auto variance_out_dims = variance_out->dims();
// auto saved_mean_dims = saved_mean->dims();
// auto saved_variance_dims = saved_variance->dims();
// CHECK_EQ(mean_out_dims.size(), 1UL);
// CHECK_EQ(variance_out_dims.size(), 1UL);
// CHECK_EQ(saved_mean_dims.size(), 1UL);
// CHECK_EQ(saved_variance_dims.size(), 1UL);
// CHECK_EQ(mean_out_dims[0], channel_size);
// CHECK_EQ(variance_out_dims[0], channel_size);
// CHECK_EQ(saved_mean_dims[0], channel_size);
// CHECK_EQ(saved_variance_dims[0], channel_size);
// }
} // namespace operators
} // namespace lite
......
......@@ -69,7 +69,7 @@ TEST(pool_op_lite, test) {
bool use_quantizer{false};
desc.SetAttr("use_quantizer", use_quantizer);
PoolOpLite pool("pool");
PoolOpLite pool("pool2d");
pool.SetValidPlaces({Place{TARGET(kARM), PRECISION(kFloat)}});
pool.Attach(desc, &scope);
auto kernels = pool.CreateKernels({Place{TARGET(kARM), PRECISION(kFloat)}});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册