提交 00615ebc 编写于 作者: Y Yu Yang

Refine OpRegistry::AddInput/AddOutput

Remove bool argument, use a class to handle that.
上级 b1b13f8f
...@@ -33,3 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry) ...@@ -33,3 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
cc_library(backward SRCS backward.cc DEPS net) cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS net)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class EmptyOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope> &scope) const override {}
void Run(const std::shared_ptr<Scope> &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class RowwiseAddOp : public EmptyOp {};
class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient();
AddOutput("Out", "Out of Add").IgnoreGradient();
AddComment("Add Op");
}
};
class RowwiseAddGradOp : public EmptyOp {};
} // namespace framework
} // namespace paddle
namespace f = paddle::framework;
REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp);
TEST(Backward, simple_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
}
\ No newline at end of file
...@@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { ...@@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker {
} }
protected: protected:
void AddInput(const std::string& name, const std::string& comment, struct VariableBuilder {
bool multiple = false, bool ignore_gradient = false) { VarProto* var_;
auto input = proto_->mutable_inputs()->Add(); std::function<void()> on_multiple_;
*input->mutable_name() = name; std::function<void()> on_temporary_;
*input->mutable_comment() = comment;
input->set_ignore_gradient(ignore_gradient); VariableBuilder& SetMultiple() {
input->set_multiple(multiple); var_->set_multiple(true);
if (multiple) { on_multiple_();
SetHasMultipleInput(); return *this;
} }
VariableBuilder& SetTemporary() {
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
return *this;
} }
void AddInputs(const std::string& name, const std::string& comment, VariableBuilder& IgnoreGradient() {
bool ignore_gradient = false) { var_->set_ignore_gradient(true);
AddInput(name, comment, true, ignore_gradient); return *this;
} }
};
void AddOutput(const std::string& name, const std::string& comment, VariableBuilder AddInput(const std::string& name,
bool temporary = false, bool multiple = false, const std::string& comment) {
bool ignore_gradient = false) { auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
nullptr};
}
VariableBuilder AddOutput(const std::string& name,
const std::string& comment) {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name; *output->mutable_name() = name;
*output->mutable_comment() = comment; *output->mutable_comment() = comment;
output->set_ignore_gradient(ignore_gradient); return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
output->set_multiple(multiple); [=] { this->SetHasTemporaryOutput(); }};
if (multiple) {
SetHasMultipleOutput();
}
output->set_temporary(temporary);
if (temporary) {
SetHasTemporaryOutput();
}
}
void AddOutputs(const std::string& name, const std::string& comment,
bool temporary = false, bool ignore_gradient = false) {
AddOutput(name, comment, temporary, true, ignore_gradient);
} }
template <typename T> template <typename T>
......
...@@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("input", "input of cosine op"); AddInput("input", "input of cosine op").SetMultiple();
AddOutput("output", "output of cosine op", AddOutput("output", "output of cosine op").SetTemporary();
/*temporary*/ true);
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
}; };
......
...@@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker) OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInputs("xs", "inputs of test op"); AddInput("xs", "inputs of test op").SetMultiple();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutputs("ys", "outputs of test op"); AddOutput("ys", "outputs of test op").SetMultiple();
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.LargerThan(0.0); .LargerThan(0.0);
......
...@@ -50,8 +50,8 @@ public: ...@@ -50,8 +50,8 @@ public:
AddInput("b", "the bias of fc operator"); AddInput("b", "the bias of fc operator");
AddOutput("Y", "the output of fc operator"); AddOutput("Y", "the output of fc operator");
AddOutput( AddOutput("before_act", "the before activation output of fc operator")
"before_act", "the before activation output of fc operator", true); .SetTemporary();
AddAttr<std::string>("activation", "The activation key for fc layer") AddAttr<std::string>("activation", "The activation key for fc layer")
.SetDefault("sigmoid") .SetDefault("sigmoid")
.InEnum({"sigmoid", "softmax"}); .InEnum({"sigmoid", "softmax"});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册