From a7a66b80f70aef70e2adb7b75ca22d87f33885f6 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 19 Sep 2017 22:28:16 -0700 Subject: [PATCH] move OpProtoAndCheckerMaker from operator to op_proto_maker --- paddle/framework/CMakeLists.txt | 4 +- paddle/framework/op_proto_maker.cc | 58 ++++++++++++++++ paddle/framework/op_proto_maker.h | 88 +++++++++++++++++++++++++ paddle/framework/op_proto_maker_test.cc | 51 ++++++++++++++ paddle/framework/op_registry.h | 1 + paddle/framework/operator.cc | 38 ----------- paddle/framework/operator.h | 65 ------------------ paddle/framework/operator_test.cc | 34 ---------- paddle/operators/prelu_op.h | 2 +- 9 files changed, 202 insertions(+), 139 deletions(-) create mode 100644 paddle/framework/op_proto_maker.cc create mode 100644 paddle/framework/op_proto_maker.h create mode 100644 paddle/framework/op_proto_maker_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 3371962c635..e535f84dba7 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,12 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) +cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) +cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) diff --git a/paddle/framework/op_proto_maker.cc b/paddle/framework/op_proto_maker.cc new file mode 100644 index 00000000000..151d61d5b17 --- /dev/null +++ b/paddle/framework/op_proto_maker.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { + +void OpProtoAndCheckerMaker::Validate() { + validated_ = true; + CheckNoDuplicatedInOutAttrs(); +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( + const std::string& name, const std::string& comment) { + auto* input = proto_->add_inputs(); + input->set_name(name); + input->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{input}; +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( + const std::string& name, const std::string& comment) { + auto* output = proto_->add_outputs(); + output->set_name(name); + output->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{output}; +} + +void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { + std::unordered_set names; + auto checker = [&](const std::string& name) { + PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); + names.insert(name); + }; + for (auto& attr : proto_->attrs()) { + checker(attr.name()); + } + for (auto& input : proto_->inputs()) { + checker(input.name()); + } + for (auto& output : proto_->outputs()) { + checker(output.name()); + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_proto_maker.h b/paddle/framework/op_proto_maker.h new file mode 100644 index 00000000000..fea15a374b1 --- /dev/null +++ b/paddle/framework/op_proto_maker.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/attribute.h" +#include "paddle/framework/framework.pb.h" + +namespace paddle { +namespace framework { + +// this class not only make proto but also init attribute checkers. +class OpProtoAndCheckerMaker { + public: + OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : proto_(proto), op_checker_(op_checker) {} + + ~OpProtoAndCheckerMaker() { + PADDLE_ENFORCE(validated_, "should call Validate after build"); + } + + void Validate(); + + protected: + struct VariableBuilder { + OpProto::Var* var_; + + VariableBuilder& AsDuplicable() { + var_->set_duplicable(true); + return *this; + } + + VariableBuilder& AsIntermediate() { + var_->set_intermediate(true); + return *this; + } + + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, const std::string& comment); + + VariableBuilder AddOutput(const std::string& name, + const std::string& comment); + + template + TypedAttrChecker& AddAttr(const std::string& name, + const std::string& comment, + bool generated = false) { + auto* attr = proto_->add_attrs(); + attr->set_name(name); + attr->set_comment(comment); + attr->set_generated(generated); + attr->set_type(AttrTypeID()); + return op_checker_->AddAttrChecker(name); + } + + void AddComment(const std::string& comment) { proto_->set_comment(comment); } + + private: + void CheckNoDuplicatedInOutAttrs(); + + OpProto* proto_; + OpAttrChecker* op_checker_; + bool validated_{false}; +}; + +class NOPMaker : public OpProtoAndCheckerMaker { + public: + NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_proto_maker_test.cc b/paddle/framework/op_proto_maker_test.cc new file mode 100644 index 00000000000..b01e30f7537 --- /dev/null +++ b/paddle/framework/op_proto_maker_test.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_proto_maker.h" + +#include "gtest/gtest.h" + +class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { + public: + TestAttrProtoMaker(paddle::framework::OpProto* proto, + paddle::framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("scale", "scale of test op"); + AddAttr("scale", "scale of test op"); + } +}; + +TEST(ProtoMaker, DuplicatedAttr) { + paddle::framework::OpProto op_proto; + paddle::framework::OpAttrChecker op_checker; + auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); + ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); +} + +class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { + public: + TestInOutProtoMaker(paddle::framework::OpProto* proto, + paddle::framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddInput("input", "input of test op"); + } +}; + +TEST(ProtoMaker, DuplicatedInOut) { + paddle::framework::OpProto op_proto; + paddle::framework::OpAttrChecker op_checker; + auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); + ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 572dff860a3..90077d01924 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" #include "paddle/framework/op_info.h" +#include "paddle/framework/op_proto_maker.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index f8a64a78661..49509af6630 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -228,43 +228,5 @@ std::vector ExecutionContext::MultiOutput( return res; } -void OpProtoAndCheckerMaker::Validate() { - validated_ = true; - CheckNoDuplicatedInOutAttrs(); -} - -OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( - const std::string& name, const std::string& comment) { - auto* input = proto_->add_inputs(); - input->set_name(name); - input->set_comment(comment); - return OpProtoAndCheckerMaker::VariableBuilder{input}; -} - -OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( - const std::string& name, const std::string& comment) { - auto* output = proto_->add_outputs(); - output->set_name(name); - output->set_comment(comment); - return OpProtoAndCheckerMaker::VariableBuilder{output}; -} - -void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { - std::unordered_set names; - auto checker = [&](const std::string& name) { - PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); - names.insert(name); - }; - for (auto& attr : proto_->attrs()) { - checker(attr.name()); - } - for (auto& input : proto_->inputs()) { - checker(input.name()); - } - for (auto& output : proto_->outputs()) { - checker(output.name()); - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index b7c9c39402d..1a78b6d1e14 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -167,71 +167,6 @@ class NOP : public OperatorBase { } }; -// this class not only make proto but also init attribute checkers. -class OpProtoAndCheckerMaker { - public: - OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : proto_(proto), op_checker_(op_checker) {} - - ~OpProtoAndCheckerMaker() { - PADDLE_ENFORCE(validated_, "should call Validate after build"); - } - - void Validate(); - - protected: - struct VariableBuilder { - OpProto::Var* var_; - - VariableBuilder& AsDuplicable() { - var_->set_duplicable(true); - return *this; - } - - VariableBuilder& AsIntermediate() { - var_->set_intermediate(true); - return *this; - } - - VariableBuilder& NotInGradient() { - var_->set_not_in_gradient(true); - return *this; - } - }; - - VariableBuilder AddInput(const std::string& name, const std::string& comment); - - VariableBuilder AddOutput(const std::string& name, - const std::string& comment); - - template - TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment, - bool generated = false) { - auto* attr = proto_->add_attrs(); - attr->set_name(name); - attr->set_comment(comment); - attr->set_generated(generated); - attr->set_type(AttrTypeID()); - return op_checker_->AddAttrChecker(name); - } - - void AddComment(const std::string& comment) { proto_->set_comment(comment); } - - private: - void CheckNoDuplicatedInOutAttrs(); - - OpProto* proto_; - OpAttrChecker* op_checker_; - bool validated_{false}; -}; - -class NOPMaker : public OpProtoAndCheckerMaker { - public: - NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) {} -}; - class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 20bbb11896a..0beab0fac5b 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -264,37 +264,3 @@ TEST(Operator, Clone) { auto b = a.Clone(); ASSERT_EQ(a.Type(), b->Type()); } - -class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - TestAttrProtoMaker(paddle::framework::OpProto* proto, - paddle::framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddAttr("scale", "scale of test op"); - AddAttr("scale", "scale of test op"); - } -}; - -TEST(ProtoMaker, DuplicatedAttr) { - paddle::framework::OpProto op_proto; - paddle::framework::OpAttrChecker op_checker; - auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); -} - -class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - TestInOutProtoMaker(paddle::framework::OpProto* proto, - paddle::framework::OpAttrChecker* op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of test op"); - AddInput("input", "input of test op"); - } -}; - -TEST(ProtoMaker, DuplicatedInOut) { - paddle::framework::OpProto op_proto; - paddle::framework::OpAttrChecker op_checker; - auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker); - ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet); -} \ No newline at end of file diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h index 3269116c112..6b78ed295cb 100644 --- a/paddle/operators/prelu_op.h +++ b/paddle/operators/prelu_op.h @@ -96,7 +96,7 @@ class PReluGradKernel : public framework::OpKernel { trans(context.device_context(), out_ptr, out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready + // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready } }; -- GitLab