From f09cb657e618aaed68d74ed87ae5599fb6136313 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 14 Aug 2017 13:51:47 +0800 Subject: [PATCH] Follow comments from WangYi --- paddle/framework/backward_test.cc | 12 +++++----- paddle/framework/grad_op_builder_test.cc | 12 +++++----- paddle/framework/op_registry.h | 9 +++++--- paddle/framework/op_registry_test.cc | 28 ++++++++++++------------ paddle/framework/operator_test.cc | 24 ++++++++++---------- paddle/operators/mean_op.cc | 2 +- paddle/operators/recurrent_op.cc | 6 ++--- 7 files changed, 48 insertions(+), 45 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index d6ba1f7d63..e1e5379009 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").NoGradient(); - AddInput("b", "Bias of Add").NoGradient(); - AddOutput("Out", "Out of Add").NoGradient(); + AddInput("X", "Input X of Add").AsNoGradient(); + AddInput("b", "Bias of Add").AsNoGradient(); + AddOutput("Out", "Out of Add").AsNoGradient(); AddComment("Add Op"); } }; @@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "x"); AddInput("W", "w"); AddInput("b", "b"); - AddOutput("mul_result", "").SetIntermediate(); - AddOutput("add_result", "").SetIntermediate(); + AddOutput("mul_result", "").AsIntermediate(); + AddOutput("add_result", "").AsIntermediate(); AddOutput("Out", ""); AddComment(""); } @@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { public: AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "x").SetDuplicable(); + AddInput("X", "x").AsDuplicable(); AddOutput("Y", "y"); AddComment(""); } diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 210e07942b..75c6ec8b56 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker { MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").SetDuplicable(); + AddInput("In2_mult", "a multiple input").AsDuplicable(); AddInput("In3", "another single input"); AddOutput("Out1", "a single output"); - AddOutput("Out2_mult", "a multiple output").SetDuplicable(); + AddOutput("Out2_mult", "a multiple output").AsDuplicable(); AddComment("test op with multiple inputs and outputs"); } }; @@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").SetDuplicable().NoGradient(); - AddInput("In3_mult", "another multiple input").SetDuplicable(); - AddOutput("Out1_mult", "a multiple output").SetDuplicable(); - AddOutput("Out2", "a single output").NoGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient(); + AddInput("In3_mult", "another multiple input").AsDuplicable(); + AddOutput("Out1_mult", "a multiple output").AsDuplicable(); + AddOutput("Out2", "a single output").AsNoGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index d840c1c4e0..e93ee14425 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker { struct VariableBuilder { OpProto::Var* var_; - VariableBuilder& SetDuplicable() { + VariableBuilder& AsDuplicable() { var_->set_duplicable(true); return *this; } - VariableBuilder& SetIntermediate() { + VariableBuilder& AsIntermediate() { var_->set_intermediate(true); return *this; } - VariableBuilder& NoGradient() { + // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it + // means that input/output is not needed when calculate gradient. It does + // not mean no gradient when backward. It should be changed soon. + VariableBuilder& AsNoGradient() { var_->set_no_gradient(true); return *this; } diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index a52dbf13af..17cbd8563c 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op").SetDuplicable(); - AddOutput("output", "output of cosine op").SetIntermediate(); + AddInput("input", "input of cosine op").AsDuplicable(); + AddOutput("output", "output of cosine op").AsIntermediate(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; @@ -51,12 +51,12 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } // namespace framework } // namespace paddle -static void ConstructVars(const std::string& param_name, - std::initializer_list arguments, - paddle::framework::OpDesc::Var* var) { +static void BuildVar(const std::string& param_name, + std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { var->set_parameter(param_name); for (auto& arg_name : arguments) { - *var->mutable_arguments()->Add() = arg_name; + var->add_arguments(arg_name); } } @@ -68,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp, TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - ConstructVars("input", {"aa"}, op_desc.add_inputs()); - ConstructVars("output", {"bb"}, op_desc.add_outputs()); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); @@ -89,8 +89,8 @@ TEST(OpRegistry, CreateOp) { TEST(OpRegistry, IllegalAttr) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - ConstructVars("input", {"aa"}, op_desc.add_inputs()); - ConstructVars("output", {"bb"}, op_desc.add_outputs()); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -114,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, DefaultValue) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - ConstructVars("input", {"aa"}, op_desc.add_inputs()); - ConstructVars("output", {"bb"}, op_desc.add_outputs()); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); ASSERT_TRUE(op_desc.IsInitialized()); @@ -130,8 +130,8 @@ TEST(OpRegistry, DefaultValue) { TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); - ConstructVars("input", {"ii"}, op_desc.add_inputs()); - ConstructVars("output", {"oo"}, op_desc.add_outputs()); + BuildVar("input", {"ii"}, op_desc.add_inputs()); + BuildVar("output", {"oo"}, op_desc.add_outputs()); // attr 'test_attr' is not set bool caught = false; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 06abb9d193..5e0280d4fa 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -56,9 +56,9 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } // namespace framework } // namespace paddle -static void ConstructVars(const std::string& param_name, - std::initializer_list arguments, - paddle::framework::OpDesc::Var* var) { +static void BuildVar(const std::string& param_name, + std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { var->set_parameter(param_name); for (auto& arg_name : arguments) { *var->mutable_arguments()->Add() = arg_name; @@ -71,8 +71,8 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, TEST(OperatorBase, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("test_operator"); - ConstructVars("IN1", {"input"}, op_desc.add_inputs()); - ConstructVars("OUT1", {"output"}, op_desc.add_outputs()); + BuildVar("IN1", {"input"}, op_desc.add_inputs()); + BuildVar("OUT1", {"output"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -132,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("xs", "inputs of test op").SetDuplicable(); + AddInput("xs", "inputs of test op").AsDuplicable(); AddInput("k", "input of test op"); - AddOutput("ys", "outputs of test op").SetDuplicable(); + AddOutput("ys", "outputs of test op").AsDuplicable(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); @@ -191,8 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, TEST(OpKernel, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); - ConstructVars("IN1", {"x"}, op_desc.add_inputs()); - ConstructVars("OUT1", {"y"}, op_desc.add_outputs()); + BuildVar("IN1", {"x"}, op_desc.add_inputs()); + BuildVar("OUT1", {"y"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -219,9 +219,9 @@ TEST(OpKernel, multi_inputs) { OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); - ConstructVars("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); - ConstructVars("k", {"k0"}, op_desc.add_inputs()); - ConstructVars("ys", {"y0", "y1"}, op_desc.add_outputs()); + BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); + BuildVar("k", {"k0"}, op_desc.add_inputs()); + BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 6e28c294b1..3b258a6bd0 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").NoGradient(); + AddOutput("Out", "The output of mean op").AsNoGradient(); AddComment("Mean Operator"); } }; diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index ff02b69276..5e6ba6b8dd 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker // inputs and outputs stored in proto AddInput(name.inlinks, "the inputs that need to be segmented for each step.") - .SetDuplicable(); + .AsDuplicable(); AddInput(name.boot_memories, "variables to initialize memories.") - .SetDuplicable(); + .AsDuplicable(); AddInput(name.step_net, "network shared by all steps."); AddOutput(name.outlinks, "the outputs that need to concated for all steps.") - .SetDuplicable(); + .AsDuplicable(); AddOutput(name.step_scopes, "step scopes"); // Attributes stored in AttributeMap -- GitLab