提交 2462d0c5 编写于 作者: Y Yu Yang 提交者: Qiao Longfei

Let OpProto support multiple and temporary (#2860)

* Let OpProto support multiple and temporary

* Each input/output of Paddle's Op could be a list. Add multiple mark to
  OpProto. Also add a `input_format`/`output_format` attribute if that
  Op has multiple input or output. The format of that attribute please
  reference the comments in `op_proto.proto`
* Add temporary mark, because some output of an Op is not used by user
  but used by other op for faster computation. Explicitly mark which
  output is temporary could let future memory/computation optimization.
* Add generated field to AttrProto.

* Add `AddInputs`/`AddOutputs` function

* It is more readable to invoke `AddInputs` not
  `AddInput(multiple=true)`.
上级 f49fda5e
...@@ -34,6 +34,11 @@ message AttrProto { ...@@ -34,6 +34,11 @@ message AttrProto {
// Supported attribute comments. It helps 3rd-party language generate doc-string. // Supported attribute comments. It helps 3rd-party language generate doc-string.
required string comment = 3; required string comment = 3;
// If that attribute is generated, it means the Paddle third language
// binding has responsibility to fill that attribute. End-User should
// not set that attribute.
optional bool generated = 4 [default=false];
} }
// Input or output message for 3rd-party language binding. // Input or output message for 3rd-party language binding.
...@@ -45,6 +50,40 @@ message VarProto { ...@@ -45,6 +50,40 @@ message VarProto {
// The comment for that input. It helps 3rd-party language generate doc-string. // The comment for that input. It helps 3rd-party language generate doc-string.
required string comment = 2; required string comment = 2;
// Is that input/output could be a list or not.
// If so, that Op should write a attributed named `input_format` or
// `output_format`.
//
// e.g.
// If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W`
// could be multiple, so the multiple of `X` and `W` is True, and OpDesc
// will hold a attribute of them.
//
// The Op desc of same fc could be
// {
// "type": "fc",
// "input": ["X1", "X2", "W1", "W2", "b"],
// "output": "fc.out",
// "attrs" : {
// "input_format": [0, 2, 4, 5]
// }
// }
//
optional bool multiple = 3 [default=false];
// It marks that output is a temporary output. That output is not used by
// user, but used by other op internally as input. If other op is not use
// that output, it could be optimized early.
//
// Attribute temporary_index will be set in OpDesc if there is some
// outputs are temporary.
//
// output = [ "xxx.out1", "xxx.tmp", "xxx.out2"],
// attrs = {
// "temporary_index": [1]
// }
optional bool temporary = 4 [default=false];
} }
// Op protocol message for 3rd-party language binding. // Op protocol message for 3rd-party language binding.
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <algorithm> #include <algorithm>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h" #include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
...@@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker { ...@@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker {
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {} : proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); }
protected: protected:
void AddInput(const std::string& name, const std::string& comment) { void AddInput(const std::string& name, const std::string& comment,
bool multiple = false) {
auto input = proto_->mutable_inputs()->Add(); auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name; *input->mutable_name() = name;
*input->mutable_comment() = comment; *input->mutable_comment() = comment;
input->set_multiple(multiple);
if (multiple) {
SetHasMultipleInput();
}
}
void AddInputs(const std::string& name, const std::string& comment) {
AddInput(name, comment, true);
} }
void AddOutput(const std::string& name, const std::string& comment) { void AddOutput(const std::string& name, const std::string& comment,
bool temporary = false, bool multiple = false) {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name; *output->mutable_name() = name;
*output->mutable_comment() = comment; *output->mutable_comment() = comment;
output->set_multiple(multiple);
if (multiple) {
SetHasMultipleOutput();
}
output->set_temporary(temporary);
if (temporary) {
SetHasTemporaryOutput();
}
}
void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false) {
AddOutput(name, comment, temporary, true);
} }
template <typename T> template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name, TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment) { const std::string& comment,
bool generated = false) {
auto attr = proto_->mutable_attrs()->Add(); auto attr = proto_->mutable_attrs()->Add();
*attr->mutable_name() = name; *attr->mutable_name() = name;
*attr->mutable_comment() = comment; *attr->mutable_comment() = comment;
attr->set_generated(generated);
AttrTypeHelper::SetAttrType<T>(attr); AttrTypeHelper::SetAttrType<T>(attr);
return op_checker_->AddAttrChecker<T>(name); return op_checker_->AddAttrChecker<T>(name);
} }
...@@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker { ...@@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker {
*(proto_->mutable_comment()) = comment; *(proto_->mutable_comment()) = comment;
} }
private:
void SetHasMultiple(const std::string& in_out, bool* flag) {
if (!*flag) {
AddAttr<std::vector<int>>(in_out + "_format",
"The multiple index of " + in_out +
"\n"
R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.
e.g.
input = ["a", "b", "c", "d", "e", "f"]
input_format = [0, 4, 5, 6]
means
The number of all input variables this op is six, and they are segmented into
three inputs.
The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
/*generated*/ true);
*flag = true;
}
}
void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
void SetHasMultipleOutput() {
SetHasMultiple("output", &has_multiple_output_);
}
void SetHasTemporaryOutput() {
if (!has_temporary_output_) {
AddAttr<std::vector<int>>("temporary_index",
R"DOC(The temporary index of output.
Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.
Add a mark to which output is temporary is helpful for future optimization.
)DOC",
/*generated*/ true)
.SetDefault(std::vector<int>());
has_temporary_output_ = true;
}
}
void CheckNoDuplicatedAttrs() {
std::unordered_set<std::string> names;
size_t cnt = 0;
for (auto& attr : proto_->attrs()) {
names.insert(attr.name());
++cnt;
}
PADDLE_ENFORCE(names.size() == cnt,
"Cannot register two attribute in same name!");
}
OpProto* proto_; OpProto* proto_;
OpAttrChecker* op_checker_; OpAttrChecker* op_checker_;
bool has_multiple_input_{false};
bool has_multiple_output_{false};
bool has_temporary_output_{false};
}; };
class OpRegistry { class OpRegistry {
......
...@@ -36,8 +36,9 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -36,8 +36,9 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op"); AddInputs("input", "input of cosine op");
AddOutput("output", "output of cosine op"); AddOutput("output", "output of cosine op",
/*temporary*/ true);
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
}; };
...@@ -117,11 +118,20 @@ TEST(OpRegistry, DefaultValue) { ...@@ -117,11 +118,20 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
} }
static void SetInputFormat(paddle::framework::OpDesc* desc) {
auto attr = desc->add_attrs();
attr->set_name("input_format");
attr->set_type(paddle::framework::INTS);
attr->mutable_ints()->Add(0);
attr->mutable_ints()->Add(1);
}
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op"); op_desc.set_type("my_test_op");
op_desc.add_inputs("ii"); op_desc.add_inputs("ii");
op_desc.add_outputs("oo"); op_desc.add_outputs("oo");
SetInputFormat(&op_desc);
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
...@@ -163,6 +173,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -163,6 +173,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
SetInputFormat(&op_desc);
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册