From 00615ebca2217c9890b1e1212eba1f5d753aa92b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 26 Jul 2017 17:50:13 +0800 Subject: [PATCH] Refine OpRegistry::AddInput/AddOutput Remove bool argument, use a class to handle that. --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/backward_test.cc | 50 +++++++++++++++++++++++ paddle/framework/op_registry.h | 61 +++++++++++++++------------- paddle/framework/op_registry_test.cc | 5 +-- paddle/framework/operator_test.cc | 4 +- paddle/operators/fc_op.cc | 4 +- 6 files changed, 89 insertions(+), 36 deletions(-) create mode 100644 paddle/framework/backward_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 26d93336b13..66f516a9637 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -33,3 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_library(backward SRCS backward.cc DEPS net) +cc_test(backward_test SRCS backward_test.cc DEPS net) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc new file mode 100644 index 00000000000..b2286facfe4 --- /dev/null +++ b/paddle/framework/backward_test.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" +namespace paddle { +namespace framework { + +class EmptyOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr &scope) const override {} + void Run(const std::shared_ptr &scope, + const platform::DeviceContext &dev_ctx) const override {} +}; + +class RowwiseAddOp : public EmptyOp {}; +class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { + public: + RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input X of Add").IgnoreGradient(); + AddInput("b", "Bias of Add").IgnoreGradient(); + AddOutput("Out", "Out of Add").IgnoreGradient(); + AddComment("Add Op"); + } +}; + +class RowwiseAddGradOp : public EmptyOp {}; +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; +REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker); +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp); + +TEST(Backward, simple_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 5bcd7ac9279..e4ac8a6e767 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { } protected: - void AddInput(const std::string& name, const std::string& comment, - bool multiple = false, bool ignore_gradient = false) { + struct VariableBuilder { + VarProto* var_; + std::function on_multiple_; + std::function on_temporary_; + + VariableBuilder& SetMultiple() { + var_->set_multiple(true); + on_multiple_(); + return *this; + } + + VariableBuilder& SetTemporary() { + PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); + var_->set_temporary(true); + on_temporary_(); + return *this; + } + + VariableBuilder& IgnoreGradient() { + var_->set_ignore_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, + const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; - input->set_ignore_gradient(ignore_gradient); - input->set_multiple(multiple); - if (multiple) { - SetHasMultipleInput(); - } - } - - void AddInputs(const std::string& name, const std::string& comment, - bool ignore_gradient = false) { - AddInput(name, comment, true, ignore_gradient); + return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, + nullptr}; } - void AddOutput(const std::string& name, const std::string& comment, - bool temporary = false, bool multiple = false, - bool ignore_gradient = false) { + VariableBuilder AddOutput(const std::string& name, + const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; - output->set_ignore_gradient(ignore_gradient); - output->set_multiple(multiple); - if (multiple) { - SetHasMultipleOutput(); - } - output->set_temporary(temporary); - if (temporary) { - SetHasTemporaryOutput(); - } - } - - void AddOutputs(const std::string& name, const std::string& comment, - bool temporary = false, bool ignore_gradient = false) { - AddOutput(name, comment, temporary, true, ignore_gradient); + return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, + [=] { this->SetHasTemporaryOutput(); }}; } template diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 2ef781bf867..a534f661af3 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("input", "input of cosine op"); - AddOutput("output", "output of cosine op", - /*temporary*/ true); + AddInput("input", "input of cosine op").SetMultiple(); + AddOutput("output", "output of cosine op").SetTemporary(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3fae356c3e5..839280abbc3 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("xs", "inputs of test op"); + AddInput("xs", "inputs of test op").SetMultiple(); AddInput("k", "input of test op"); - AddOutputs("ys", "outputs of test op"); + AddOutput("ys", "outputs of test op").SetMultiple(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index c4a9f5937f4..71ceda95877 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -50,8 +50,8 @@ public: AddInput("b", "the bias of fc operator"); AddOutput("Y", "the output of fc operator"); - AddOutput( - "before_act", "the before activation output of fc operator", true); + AddOutput("before_act", "the before activation output of fc operator") + .SetTemporary(); AddAttr("activation", "The activation key for fc layer") .SetDefault("sigmoid") .InEnum({"sigmoid", "softmax"}); -- GitLab