未验证 提交 bfa3fd6f 编写于 作者: D dzhwinter 提交者: GitHub

add inplace attribute to op_proto_maker (#10665)

* "add inplace attribute"

* "register inplace attribute"

* "change se-next model for memory-reuse"

* "fix typo"

* repick

* fix merge conflict

* "fix stupid error"
上级 9087c668
......@@ -71,6 +71,7 @@ message OpProto {
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ];
optional string reuse = 6;
}
// AttrProto describes the C++ type Attribute.
......
......@@ -21,6 +21,7 @@ namespace framework {
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
CheckReuseVars();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
......@@ -56,6 +57,24 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
}
}
void OpProtoAndCheckerMaker::CheckReuseVars() {
std::unordered_set<std::string> names;
for (auto& input : proto_->inputs()) {
names.insert(input.name());
}
auto checker = [&](const std::string& name, const std::string& reused) {
PADDLE_ENFORCE(
names.count(reused),
"Output [%s] reuse Input [%s], but the input is not registered.", name,
reused);
};
for (auto& output : proto_->outputs()) {
if (output.has_reuse()) {
checker(output.name(), output.reuse());
}
}
}
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) {
proto_ = proto;
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -64,6 +66,11 @@ class OpProtoAndCheckerMaker {
var_->set_dispensable(true);
return *this;
}
VariableBuilder &Reuse(const std::string &name) {
var_->set_reuse(name);
return *this;
}
};
VariableBuilder AddInput(const std::string &name, const std::string &comment);
......@@ -89,6 +96,8 @@ class OpProtoAndCheckerMaker {
void CheckNoDuplicatedInOutAttrs();
void Validate();
void CheckReuseVars();
proto::OpProto *proto_;
OpAttrChecker *op_checker_;
bool validated_{false};
......
......@@ -47,3 +47,23 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet);
}
class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
AddOutput("NoOut", "output of test op").Reuse("NotExists");
}
};
TEST(ProtoMaker, InplaceOutput) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
TestInplaceProtoMaker proto_maker;
ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet);
// proto_maker(&op_proto, &op_checker);
// proto_maker.Make();
// ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}
......@@ -25,7 +25,7 @@ namespace operators {
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
......
......@@ -89,9 +89,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("ParamOut", "(Tensor) Output parameter").Reuse("Param");
AddOutput("Moment1Out", "(Tensor) Output first moment").Reuse("Moment1");
AddOutput("Moment2Out", "(Tensor) Output second moment").Reuse("Moment2");
AddAttr<float>("beta1",
"(float, default 0.9) "
......
......@@ -151,13 +151,15 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance",
"The global variance (for training) "
"or estimated Variance (for testing)");
AddOutput("Y", "result after normalization");
AddOutput("Y", "result after normalization").Reuse("X");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training");
"Store the global mean when training")
.Reuse("Mean");
AddOutput("VarianceOut",
"Share memory with Variance. "
"Store the global Variance when training");
"Store the global Variance when training")
.Reuse("Variance");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
......
......@@ -125,7 +125,8 @@ void Conv2DOpMaker::Make() {
"input image channels divided by the groups.");
AddOutput("Output",
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.");
"The format of output tensor is also NCHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......@@ -220,7 +221,8 @@ void Conv3DOpMaker::Make() {
"input image channels divided by the groups.");
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.");
"The format of output tensor is also NCDHW.")
.Reuse("Input");
AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
......
......@@ -124,7 +124,8 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor<float/double> with shape [N x D].");
AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape "
"[N x 1]. The cross entropy loss.");
"[N x 1]. The cross entropy loss.")
.Reuse("X");
AddAttr<bool>("soft_label",
"(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.")
......
......@@ -59,7 +59,7 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
AddOutput("Out", "The output of elementwise op.");
AddOutput("Out", "The output of elementwise op.").Reuse("X");
AddAttr<int>("axis",
"(int, default -1). The start dimension index "
"for broadcasting Y onto X.")
......
......@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op");
AddOutput("Out", "The output of mean op").Reuse("X");
AddComment(R"DOC(
Mean Operator.
......
......@@ -151,7 +151,8 @@ void Pool2dOpMaker::Make() {
"The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, "
"H is the height of the feature, "
"and W is the width of the feature.");
"and W is the width of the feature.")
.Reuse("X");
AddAttr<std::string>("pooling_type",
"(string), pooling type, can be \"max\" for max-pooling "
......@@ -244,7 +245,8 @@ void Pool3dOpMaker::Make() {
"The format of output tensor is also NCDHW, "
"where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively.");
"width of the feature, respectively.")
.Reuse("X");
AddAttr<std::string>("pooling_type",
"(string) Pooling type, can be \"max\" for max-pooling "
......
......@@ -74,7 +74,8 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param");
"Output parameter, should share the same memory with Param")
.Reuse("Param");
AddComment(R"DOC(
SGD operator
......
......@@ -83,7 +83,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "The normalized values with the same shape as X.");
AddOutput("Out", "The normalized values with the same shape as X.")
.Reuse("X");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
......
......@@ -115,7 +115,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.");
AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X");
AddComment(R"DOC(
Sum operator.
......
......@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC(
Top K operator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册