提交 f09cb657 编写于 作者: Y Yu Yang

Follow comments from WangYi

上级 ef29b522
...@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { ...@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").NoGradient(); AddInput("X", "Input X of Add").AsNoGradient();
AddInput("b", "Bias of Add").NoGradient(); AddInput("b", "Bias of Add").AsNoGradient();
AddOutput("Out", "Out of Add").NoGradient(); AddOutput("Out", "Out of Add").AsNoGradient();
AddComment("Add Op"); AddComment("Add Op");
} }
}; };
...@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker { ...@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x"); AddInput("X", "x");
AddInput("W", "w"); AddInput("W", "w");
AddInput("b", "b"); AddInput("b", "b");
AddOutput("mul_result", "").SetIntermediate(); AddOutput("mul_result", "").AsIntermediate();
AddOutput("add_result", "").SetIntermediate(); AddOutput("add_result", "").AsIntermediate();
AddOutput("Out", ""); AddOutput("Out", "");
AddComment(""); AddComment("");
} }
...@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public: public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetDuplicable(); AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y"); AddOutput("Y", "y");
AddComment(""); AddComment("");
} }
......
...@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker { ...@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetDuplicable(); AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input"); AddInput("In3", "another single input");
AddOutput("Out1", "a single output"); AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").SetDuplicable(); AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs"); AddComment("test op with multiple inputs and outputs");
} }
}; };
...@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { ...@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetDuplicable().NoGradient(); AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
AddInput("In3_mult", "another multiple input").SetDuplicable(); AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").SetDuplicable(); AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").NoGradient(); AddOutput("Out2", "a single output").AsNoGradient();
AddComment("op with inputs and outputs ignored in gradient calculating"); AddComment("op with inputs and outputs ignored in gradient calculating");
} }
}; };
......
...@@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker { ...@@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker {
struct VariableBuilder { struct VariableBuilder {
OpProto::Var* var_; OpProto::Var* var_;
VariableBuilder& SetDuplicable() { VariableBuilder& AsDuplicable() {
var_->set_duplicable(true); var_->set_duplicable(true);
return *this; return *this;
} }
VariableBuilder& SetIntermediate() { VariableBuilder& AsIntermediate() {
var_->set_intermediate(true); var_->set_intermediate(true);
return *this; return *this;
} }
VariableBuilder& NoGradient() { // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true); var_->set_no_gradient(true);
return *this; return *this;
} }
......
...@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").SetDuplicable(); AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").SetIntermediate(); AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
}; };
...@@ -51,12 +51,12 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -51,12 +51,12 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
static void ConstructVars(const std::string& param_name, static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments, std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) { paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name); var->set_parameter(param_name);
for (auto& arg_name : arguments) { for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name; var->add_arguments(arg_name);
} }
} }
...@@ -68,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp, ...@@ -68,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
float scale = 3.3; float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
...@@ -89,8 +89,8 @@ TEST(OpRegistry, CreateOp) { ...@@ -89,8 +89,8 @@ TEST(OpRegistry, CreateOp) {
TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -114,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -114,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
TEST(OpRegistry, DefaultValue) { TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs()); BuildVar("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs()); BuildVar("output", {"bb"}, op_desc.add_outputs());
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
...@@ -130,8 +130,8 @@ TEST(OpRegistry, DefaultValue) { ...@@ -130,8 +130,8 @@ TEST(OpRegistry, DefaultValue) {
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op"); op_desc.set_type("my_test_op");
ConstructVars("input", {"ii"}, op_desc.add_inputs()); BuildVar("input", {"ii"}, op_desc.add_inputs());
ConstructVars("output", {"oo"}, op_desc.add_outputs()); BuildVar("output", {"oo"}, op_desc.add_outputs());
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
......
...@@ -56,9 +56,9 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -56,9 +56,9 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
static void ConstructVars(const std::string& param_name, static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments, std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) { paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name); var->set_parameter(param_name);
for (auto& arg_name : arguments) { for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name; *var->mutable_arguments()->Add() = arg_name;
...@@ -71,8 +71,8 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, ...@@ -71,8 +71,8 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
TEST(OperatorBase, all) { TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator"); op_desc.set_type("test_operator");
ConstructVars("IN1", {"input"}, op_desc.add_inputs()); BuildVar("IN1", {"input"}, op_desc.add_inputs());
ConstructVars("OUT1", {"output"}, op_desc.add_outputs()); BuildVar("OUT1", {"output"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -132,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -132,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker) OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").SetDuplicable(); AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").SetDuplicable(); AddOutput("ys", "outputs of test op").AsDuplicable();
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.LargerThan(0.0); .LargerThan(0.0);
...@@ -191,8 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, ...@@ -191,8 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST(OpKernel, all) { TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel"); op_desc.set_type("op_with_kernel");
ConstructVars("IN1", {"x"}, op_desc.add_inputs()); BuildVar("IN1", {"x"}, op_desc.add_inputs());
ConstructVars("OUT1", {"y"}, op_desc.add_outputs()); BuildVar("OUT1", {"y"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -219,9 +219,9 @@ TEST(OpKernel, multi_inputs) { ...@@ -219,9 +219,9 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc; OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel"); op_desc.set_type("op_multi_inputs_with_kernel");
ConstructVars("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
ConstructVars("k", {"k0"}, op_desc.add_inputs()); BuildVar("k", {"k0"}, op_desc.add_inputs());
ConstructVars("ys", {"y0", "y1"}, op_desc.add_outputs()); BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
......
...@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NoGradient(); AddOutput("Out", "The output of mean op").AsNoGradient();
AddComment("Mean Operator"); AddComment("Mean Operator");
} }
}; };
......
...@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker ...@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
// inputs and outputs stored in proto // inputs and outputs stored in proto
AddInput(name.inlinks, AddInput(name.inlinks,
"the inputs that need to be segmented for each step.") "the inputs that need to be segmented for each step.")
.SetDuplicable(); .AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.") AddInput(name.boot_memories, "variables to initialize memories.")
.SetDuplicable(); .AsDuplicable();
AddInput(name.step_net, "network shared by all steps."); AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the outputs that need to concated for all steps.") AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.SetDuplicable(); .AsDuplicable();
AddOutput(name.step_scopes, "step scopes"); AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap // Attributes stored in AttributeMap
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册