From 2462d0c5fedb783a322170ff15f828e63b612ead Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 14 Jul 2017 00:50:46 -0500 Subject: [PATCH] Let OpProto support multiple and temporary (#2860) * Let OpProto support multiple and temporary * Each input/output of Paddle's Op could be a list. Add multiple mark to OpProto. Also add a `input_format`/`output_format` attribute if that Op has multiple input or output. The format of that attribute please reference the comments in `op_proto.proto` * Add temporary mark, because some output of an Op is not used by user but used by other op for faster computation. Explicitly mark which output is temporary could let future memory/computation optimization. * Add generated field to AttrProto. * Add `AddInputs`/`AddOutputs` function * It is more readable to invoke `AddInputs` not `AddInput(multiple=true)`. --- paddle/framework/op_proto.proto | 39 +++++++++++ paddle/framework/op_registry.h | 97 +++++++++++++++++++++++++++- paddle/framework/op_registry_test.cc | 15 ++++- 3 files changed, 146 insertions(+), 5 deletions(-) diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto index 22df6f9c6b7..596b8588e78 100644 --- a/paddle/framework/op_proto.proto +++ b/paddle/framework/op_proto.proto @@ -34,6 +34,11 @@ message AttrProto { // Supported attribute comments. It helps 3rd-party language generate doc-string. required string comment = 3; + + // If that attribute is generated, it means the Paddle third language + // binding has responsibility to fill that attribute. End-User should + // not set that attribute. + optional bool generated = 4 [default=false]; } // Input or output message for 3rd-party language binding. @@ -45,6 +50,40 @@ message VarProto { // The comment for that input. It helps 3rd-party language generate doc-string. required string comment = 2; + + // Is that input/output could be a list or not. + // If so, that Op should write a attributed named `input_format` or + // `output_format`. + // + // e.g. + // If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W` + // could be multiple, so the multiple of `X` and `W` is True, and OpDesc + // will hold a attribute of them. + // + // The Op desc of same fc could be + // { + // "type": "fc", + // "input": ["X1", "X2", "W1", "W2", "b"], + // "output": "fc.out", + // "attrs" : { + // "input_format": [0, 2, 4, 5] + // } + // } + // + optional bool multiple = 3 [default=false]; + + // It marks that output is a temporary output. That output is not used by + // user, but used by other op internally as input. If other op is not use + // that output, it could be optimized early. + // + // Attribute temporary_index will be set in OpDesc if there is some + // outputs are temporary. + // + // output = [ "xxx.out1", "xxx.tmp", "xxx.out2"], + // attrs = { + // "temporary_index": [1] + // } + optional bool temporary = 4 [default=false]; } // Op protocol message for 3rd-party language binding. diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 61dfcb70496..d049599a2fd 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" @@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker { OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : proto_(proto), op_checker_(op_checker) {} + ~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); } + protected: - void AddInput(const std::string& name, const std::string& comment) { + void AddInput(const std::string& name, const std::string& comment, + bool multiple = false) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; + input->set_multiple(multiple); + if (multiple) { + SetHasMultipleInput(); + } + } + + void AddInputs(const std::string& name, const std::string& comment) { + AddInput(name, comment, true); } - void AddOutput(const std::string& name, const std::string& comment) { + void AddOutput(const std::string& name, const std::string& comment, + bool temporary = false, bool multiple = false) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; + output->set_multiple(multiple); + if (multiple) { + SetHasMultipleOutput(); + } + output->set_temporary(temporary); + if (temporary) { + SetHasTemporaryOutput(); + } + } + + void AddOutputs(const std::string& name, const std::string& comment, + bool temporary = false) { + AddOutput(name, comment, temporary, true); } template TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment) { + const std::string& comment, + bool generated = false) { auto attr = proto_->mutable_attrs()->Add(); *attr->mutable_name() = name; *attr->mutable_comment() = comment; + attr->set_generated(generated); AttrTypeHelper::SetAttrType(attr); return op_checker_->AddAttrChecker(name); } @@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker { *(proto_->mutable_comment()) = comment; } + private: + void SetHasMultiple(const std::string& in_out, bool* flag) { + if (!*flag) { + AddAttr>(in_out + "_format", + "The multiple index of " + in_out + + "\n" + R"DOC( +This attribute is used by Paddle core framework. Paddle's Op support each input +or output could be a list of variable. This attribute is used to show how that +list organized. + +e.g. + input = ["a", "b", "c", "d", "e", "f"] + input_format = [0, 4, 5, 6] + +means + The number of all input variables this op is six, and they are segmented into + three inputs. + + The first input is input[0:4], second is input[4:5], third is input[5:6]. +)DOC", + /*generated*/ true); + *flag = true; + } + } + + void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); } + void SetHasMultipleOutput() { + SetHasMultiple("output", &has_multiple_output_); + } + + void SetHasTemporaryOutput() { + if (!has_temporary_output_) { + AddAttr>("temporary_index", + R"DOC(The temporary index of output. + +Not all output of Paddle Op is used by user. For faster computation, each op +could output some its internal state to other op, other op could take that +output to make compute faster. + +Add a mark to which output is temporary is helpful for future optimization. +)DOC", + /*generated*/ true) + .SetDefault(std::vector()); + has_temporary_output_ = true; + } + } + + void CheckNoDuplicatedAttrs() { + std::unordered_set names; + size_t cnt = 0; + for (auto& attr : proto_->attrs()) { + names.insert(attr.name()); + ++cnt; + } + PADDLE_ENFORCE(names.size() == cnt, + "Cannot register two attribute in same name!"); + } + OpProto* proto_; OpAttrChecker* op_checker_; + bool has_multiple_input_{false}; + bool has_multiple_output_{false}; + bool has_temporary_output_{false}; }; class OpRegistry { diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 9bcc0407add..1adafa3714e 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -36,8 +36,9 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); + AddInputs("input", "input of cosine op"); + AddOutput("output", "output of cosine op", + /*temporary*/ true); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; @@ -117,11 +118,20 @@ TEST(OpRegistry, DefaultValue) { ASSERT_EQ(op->GetAttr("scale"), 1.0); } +static void SetInputFormat(paddle::framework::OpDesc* desc) { + auto attr = desc->add_attrs(); + attr->set_name("input_format"); + attr->set_type(paddle::framework::INTS); + attr->mutable_ints()->Add(0); + attr->mutable_ints()->Add(1); +} + TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); op_desc.add_inputs("ii"); op_desc.add_outputs("oo"); + SetInputFormat(&op_desc); // attr 'test_attr' is not set bool caught = false; @@ -163,6 +173,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); + SetInputFormat(&op_desc); paddle::framework::OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; -- GitLab