From ef29b5224bc4588ae2f9bc8787a395faba40f571 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 14 Aug 2017 13:00:36 +0800 Subject: [PATCH] Simplify unit test code --- paddle/framework/op_registry_test.cc | 28 ++++++++-------------------- paddle/framework/operator_test.cc | 24 +++++++----------------- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index ec7430a95fa..a52dbf13af7 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -68,11 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp, TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - auto* input = op_desc.add_inputs(); - ConstructVars("input", {"aa"}, input); - - auto* output = op_desc.add_outputs(); - ConstructVars("output", {"bb"}, output); + ConstructVars("input", {"aa"}, op_desc.add_inputs()); + ConstructVars("output", {"bb"}, op_desc.add_outputs()); float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); @@ -92,11 +89,8 @@ TEST(OpRegistry, CreateOp) { TEST(OpRegistry, IllegalAttr) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - auto* input = op_desc.add_inputs(); - ConstructVars("input", {"aa"}, input); - - auto* output = op_desc.add_outputs(); - ConstructVars("output", {"bb"}, output); + ConstructVars("input", {"aa"}, op_desc.add_inputs()); + ConstructVars("output", {"bb"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -120,11 +114,8 @@ TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, DefaultValue) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - auto* input = op_desc.add_inputs(); - ConstructVars("input", {"aa"}, input); - - auto* output = op_desc.add_outputs(); - ConstructVars("output", {"bb"}, output); + ConstructVars("input", {"aa"}, op_desc.add_inputs()); + ConstructVars("output", {"bb"}, op_desc.add_outputs()); ASSERT_TRUE(op_desc.IsInitialized()); @@ -139,11 +130,8 @@ TEST(OpRegistry, DefaultValue) { TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); - auto* input = op_desc.add_inputs(); - ConstructVars("input", {"ii"}, input); - - auto* output = op_desc.add_outputs(); - ConstructVars("output", {"oo"}, output); + ConstructVars("input", {"ii"}, op_desc.add_inputs()); + ConstructVars("output", {"oo"}, op_desc.add_outputs()); // attr 'test_attr' is not set bool caught = false; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 46e419a8c88..06abb9d1934 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -71,12 +71,8 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, TEST(OperatorBase, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("test_operator"); - - auto* ipt = op_desc.mutable_inputs()->Add(); - ConstructVars("IN1", {"input"}, ipt); - - auto* output = op_desc.mutable_outputs()->Add(); - ConstructVars("OUT1", {"output"}, output); + ConstructVars("IN1", {"input"}, op_desc.add_inputs()); + ConstructVars("OUT1", {"output"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -195,11 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, TEST(OpKernel, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); - auto* ipt = op_desc.mutable_inputs()->Add(); - ConstructVars("IN1", {"x"}, ipt); - - auto* output = op_desc.mutable_outputs()->Add(); - ConstructVars("OUT1", {"y"}, output); + ConstructVars("IN1", {"x"}, op_desc.add_inputs()); + ConstructVars("OUT1", {"y"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -226,12 +219,9 @@ TEST(OpKernel, multi_inputs) { OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); - auto* x = op_desc.mutable_inputs()->Add(); - ConstructVars("xs", {"x0", "x1", "x2"}, x); - auto* k = op_desc.mutable_inputs()->Add(); - ConstructVars("k", {"k0"}, k); - auto* y = op_desc.mutable_outputs()->Add(); - ConstructVars("ys", {"y0", "y1"}, y); + ConstructVars("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); + ConstructVars("k", {"k0"}, op_desc.add_inputs()); + ConstructVars("ys", {"y0", "y1"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); -- GitLab