提交 f09cb657 编写于 作者: Y Yu Yang

Follow comments from WangYi

上级 ef29b522
......@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").NoGradient();
AddInput("b", "Bias of Add").NoGradient();
AddOutput("Out", "Out of Add").NoGradient();
AddInput("X", "Input X of Add").AsNoGradient();
AddInput("b", "Bias of Add").AsNoGradient();
AddOutput("Out", "Out of Add").AsNoGradient();
AddComment("Add Op");
}
};
......@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("mul_result", "").SetIntermediate();
AddOutput("add_result", "").SetIntermediate();
AddOutput("mul_result", "").AsIntermediate();
AddOutput("add_result", "").AsIntermediate();
AddOutput("Out", "");
AddComment("");
}
......@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetDuplicable();
AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y");
AddComment("");
}
......
......@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetDuplicable();
AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input");
AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").SetDuplicable();
AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs");
}
};
......@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetDuplicable().NoGradient();
AddInput("In3_mult", "another multiple input").SetDuplicable();
AddOutput("Out1_mult", "a multiple output").SetDuplicable();
AddOutput("Out2", "a single output").NoGradient();
AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").AsNoGradient();
AddComment("op with inputs and outputs ignored in gradient calculating");
}
};
......
......@@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker {
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& SetDuplicable() {
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& SetIntermediate() {
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NoGradient() {
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this;
}
......
......@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").SetDuplicable();
AddOutput("output", "output of cosine op").SetIntermediate();
AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
};
......@@ -51,12 +51,12 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle
static void ConstructVars(const std::string& param_name,
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
var->add_arguments(arg_name);
}
}
......@@ -68,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs());
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add();
......@@ -89,8 +89,8 @@ TEST(OpRegistry, CreateOp) {
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs());
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -114,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
ConstructVars("input", {"aa"}, op_desc.add_inputs());
ConstructVars("output", {"bb"}, op_desc.add_outputs());
BuildVar("input", {"aa"}, op_desc.add_inputs());
BuildVar("output", {"bb"}, op_desc.add_outputs());
ASSERT_TRUE(op_desc.IsInitialized());
......@@ -130,8 +130,8 @@ TEST(OpRegistry, DefaultValue) {
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
ConstructVars("input", {"ii"}, op_desc.add_inputs());
ConstructVars("output", {"oo"}, op_desc.add_outputs());
BuildVar("input", {"ii"}, op_desc.add_inputs());
BuildVar("output", {"oo"}, op_desc.add_outputs());
// attr 'test_attr' is not set
bool caught = false;
......
......@@ -56,7 +56,7 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework
} // namespace paddle
static void ConstructVars(const std::string& param_name,
static void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
......@@ -71,8 +71,8 @@ REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator");
ConstructVars("IN1", {"input"}, op_desc.add_inputs());
ConstructVars("OUT1", {"output"}, op_desc.add_outputs());
BuildVar("IN1", {"input"}, op_desc.add_inputs());
BuildVar("OUT1", {"output"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -132,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").SetDuplicable();
AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").SetDuplicable();
AddOutput("ys", "outputs of test op").AsDuplicable();
AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0)
.LargerThan(0.0);
......@@ -191,8 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel");
ConstructVars("IN1", {"x"}, op_desc.add_inputs());
ConstructVars("OUT1", {"y"}, op_desc.add_outputs());
BuildVar("IN1", {"x"}, op_desc.add_inputs());
BuildVar("OUT1", {"y"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......@@ -219,9 +219,9 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel");
ConstructVars("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
ConstructVars("k", {"k0"}, op_desc.add_inputs());
ConstructVars("ys", {"y0", "y1"}, op_desc.add_outputs());
BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
BuildVar("k", {"k0"}, op_desc.add_inputs());
BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
......
......@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").NoGradient();
AddOutput("Out", "The output of mean op").AsNoGradient();
AddComment("Mean Operator");
}
};
......
......@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.SetDuplicable();
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.SetDuplicable();
.AsDuplicable();
AddInput(name.step_net, "network shared by all steps.");
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.SetDuplicable();
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册