未验证 提交 bfa3fd6f 编写于 作者: D dzhwinter 提交者: GitHub

add inplace attribute to op_proto_maker (#10665)

* "add inplace attribute"

* "register inplace attribute"

* "change se-next model for memory-reuse"

* "fix typo"

* repick

* fix merge conflict

* "fix stupid error"
上级 9087c668
...@@ -71,6 +71,7 @@ message OpProto { ...@@ -71,6 +71,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ]; optional bool dispensable = 5 [ default = false ];
optional string reuse = 6;
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -21,6 +21,7 @@ namespace framework { ...@@ -21,6 +21,7 @@ namespace framework {
void OpProtoAndCheckerMaker::Validate() { void OpProtoAndCheckerMaker::Validate() {
validated_ = true; validated_ = true;
CheckNoDuplicatedInOutAttrs(); CheckNoDuplicatedInOutAttrs();
CheckReuseVars();
} }
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
...@@ -56,6 +57,24 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { ...@@ -56,6 +57,24 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
} }
} }
void OpProtoAndCheckerMaker::CheckReuseVars() {
std::unordered_set<std::string> names;
for (auto& input : proto_->inputs()) {
names.insert(input.name());
}
auto checker = [&](const std::string& name, const std::string& reused) {
PADDLE_ENFORCE(
names.count(reused),
"Output [%s] reuse Input [%s], but the input is not registered.", name,
reused);
};
for (auto& output : proto_->outputs()) {
if (output.has_reuse()) {
checker(output.name(), output.reuse());
}
}
}
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) { OpAttrChecker* attr_checker) {
proto_ = proto; proto_ = proto;
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -64,6 +66,11 @@ class OpProtoAndCheckerMaker { ...@@ -64,6 +66,11 @@ class OpProtoAndCheckerMaker {
var_->set_dispensable(true); var_->set_dispensable(true);
return *this; return *this;
} }
VariableBuilder &Reuse(const std::string &name) {
var_->set_reuse(name);
return *this;
}
}; };
VariableBuilder AddInput(const std::string &name, const std::string &comment); VariableBuilder AddInput(const std::string &name, const std::string &comment);
...@@ -89,6 +96,8 @@ class OpProtoAndCheckerMaker { ...@@ -89,6 +96,8 @@ class OpProtoAndCheckerMaker {
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
void Validate(); void Validate();
void CheckReuseVars();
proto::OpProto *proto_; proto::OpProto *proto_;
OpAttrChecker *op_checker_; OpAttrChecker *op_checker_;
bool validated_{false}; bool validated_{false};
......
...@@ -47,3 +47,23 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -47,3 +47,23 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW(proto_maker(&op_proto, &op_checker), ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet); paddle::platform::EnforceNotMet);
} }
class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
AddOutput("NoOut", "output of test op").Reuse("NotExists");
}
};
TEST(ProtoMaker, InplaceOutput) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
TestInplaceProtoMaker proto_maker;
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet);
// proto_maker(&op_proto, &op_checker);
// proto_maker.Make();
// ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
public: \ public: \
void Make() override { \ void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \ AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \ AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \ AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \ "(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \ .SetDefault(false); \
......
...@@ -89,9 +89,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -89,9 +89,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("ParamOut", "(Tensor) Output parameter").Reuse("Param");
AddOutput("Moment1Out", "(Tensor) Output first moment"); AddOutput("Moment1Out", "(Tensor) Output first moment").Reuse("Moment1");
AddOutput("Moment2Out", "(Tensor) Output second moment"); AddOutput("Moment2Out", "(Tensor) Output second moment").Reuse("Moment2");
AddAttr<float>("beta1", AddAttr<float>("beta1",
"(float, default 0.9) " "(float, default 0.9) "
......
...@@ -151,13 +151,15 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -151,13 +151,15 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance", AddInput("Variance",
"The global variance (for training) " "The global variance (for training) "
"or estimated Variance (for testing)"); "or estimated Variance (for testing)");
AddOutput("Y", "result after normalization"); AddOutput("Y", "result after normalization").Reuse("X");
AddOutput("MeanOut", AddOutput("MeanOut",
"Share memory with Mean. " "Share memory with Mean. "
"Store the global mean when training"); "Store the global mean when training")
.Reuse("Mean");
AddOutput("VarianceOut", AddOutput("VarianceOut",
"Share memory with Variance. " "Share memory with Variance. "
"Store the global Variance when training"); "Store the global Variance when training")
.Reuse("Variance");
AddOutput("SavedMean", AddOutput("SavedMean",
"Mean of the current mini batch, " "Mean of the current mini batch, "
"will apply to output when training") "will apply to output when training")
......
...@@ -125,7 +125,8 @@ void Conv2DOpMaker::Make() { ...@@ -125,7 +125,8 @@ void Conv2DOpMaker::Make() {
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the " "(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of " "strides(h_stride, w_stride) of "
...@@ -220,7 +221,8 @@ void Conv3DOpMaker::Make() { ...@@ -220,7 +221,8 @@ void Conv3DOpMaker::Make() {
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator." "(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the " "(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of " "strides(d_stride, h_stride, w_stride) of "
......
...@@ -124,7 +124,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -124,7 +124,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor<float/double> with shape [N x D]."); "Tensor<float/double> with shape [N x D].");
AddOutput("Y", AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape " "(Tensor, default Tensor<float>), a 2-D tensor with shape "
"[N x 1]. The cross entropy loss."); "[N x 1]. The cross entropy loss.")
.Reuse("X");
AddAttr<bool>("soft_label", AddAttr<bool>("soft_label",
"(bool, default false), a flag indicating whether to " "(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.") "interpretate the given labels as soft labels.")
......
...@@ -59,7 +59,7 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -59,7 +59,7 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() final { void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op."); AddOutput("Out", "The output of elementwise op.").Reuse("X");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default -1). The start dimension index " "(int, default -1). The start dimension index "
"for broadcasting Y onto X.") "for broadcasting Y onto X.")
......
...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "The input of mean op"); AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op"); AddOutput("Out", "The output of mean op").Reuse("X");
AddComment(R"DOC( AddComment(R"DOC(
Mean Operator. Mean Operator.
......
...@@ -151,7 +151,8 @@ void Pool2dOpMaker::Make() { ...@@ -151,7 +151,8 @@ void Pool2dOpMaker::Make() {
"The format of output tensor is also NCHW, " "The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, " "where N is batch size, C is the number of channels, "
"H is the height of the feature, " "H is the height of the feature, "
"and W is the width of the feature."); "and W is the width of the feature.")
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
...@@ -244,7 +245,8 @@ void Pool3dOpMaker::Make() { ...@@ -244,7 +245,8 @@ void Pool3dOpMaker::Make() {
"The format of output tensor is also NCDHW, " "The format of output tensor is also NCDHW, "
"where N is batch size, C is " "where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and " "the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively."); "width of the feature, respectively.")
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string) Pooling type, can be \"max\" for max-pooling " "(string) Pooling type, can be \"max\" for max-pooling "
......
...@@ -74,7 +74,8 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -74,7 +74,8 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut", AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) " "(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param"); "Output parameter, should share the same memory with Param")
.Reuse("Param");
AddComment(R"DOC( AddComment(R"DOC(
SGD operator SGD operator
......
...@@ -83,7 +83,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -83,7 +83,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."); "2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "The normalized values with the same shape as X."); AddOutput("Out", "The normalized values with the same shape as X.")
.Reuse("X");
AddAttr<bool>( AddAttr<bool>(
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
......
...@@ -115,7 +115,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -115,7 +115,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator."); AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
AddComment(R"DOC( AddComment(R"DOC(
Sum operator. Sum operator.
......
...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of Topk op"); AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC( AddComment(R"DOC(
Top K operator Top K operator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册