diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 68304c9fc8b8fa13cb1f99b82517abc87c71496c..59012ea8c1e52e05f402706b3bfff5a4eee2715c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) - -cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute) +cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 2c5ec76dfeb8b8485951e4d94896b6758e0cb930..bcdfae132c6bfd924b8c8383ec69478f77ceaa22 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { class FcOp : public operators::NetOp { public: - FcOp(const std::string &type, const VarNameMap &inputs, - const VarNameMap &outputs, const AttributeMap &attrs) + FcOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { AddOp(OpRegistry::CreateOp("mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}}, diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 0a2a41f6b62658ac8633a6e384d099f8d6641f33..fcc5d7a216a6f1bb3b0752ba902180991d7768b8 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,11 +20,11 @@ namespace framework { enum class OpArgType { IN, OUT }; static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, - bool is_grad, OperatorBase::VarNameMap* vars) { + bool is_grad, VariableNameMap* vars) { const auto& src_inout = src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); auto& dst_inout = *vars; - const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_; + const OpProto* proto = OpInfoMap().at(src_op->Type()).proto_; const auto& src_arg_list = src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { @@ -40,25 +40,25 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, } OperatorBase* BuildGradOp(const OperatorBase* op) { - auto it = OpRegistry::op_info_map().find(op->Type()); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", op->Type()); + auto it = OpInfoMap().find(op->Type()); + PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", + op->Type()); PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", op->Type()); std::string grad_op_type = it->second.grad_op_type_; PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", op->Type()); - OperatorBase::VarNameMap inputs; - OperatorBase::VarNameMap outputs; + VariableNameMap inputs; + VariableNameMap outputs; TransOpArg(op, OpArgType::IN, false, &inputs); // I TransOpArg(op, OpArgType::OUT, false, &inputs); // O TransOpArg(op, OpArgType::OUT, true, &inputs); // OG TransOpArg(op, OpArgType::IN, true, &outputs); // IG - it = OpRegistry::op_info_map().find(grad_op_type); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", grad_op_type); + it = OpInfoMap().find(grad_op_type); + PADDLE_ENFORCE(it != OpInfoMap().end(), "'%s' has not been registered.", + grad_op_type); return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); } diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..f928ac647398d43f9e1956cf57e0b5e86695da43 --- /dev/null +++ b/paddle/framework/op_info.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_info.h" + +namespace paddle { +namespace framework { + +static std::unordered_map* + g_op_info_map = nullptr; +std::unordered_map& OpInfoMap() { + if (g_op_info_map == nullptr) { + g_op_info_map = + new std::unordered_map(); + } + return *g_op_info_map; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h new file mode 100644 index 0000000000000000000000000000000000000000..fdd0ed77d4b0ac1f2cd03fd93cf04550ac977950 --- /dev/null +++ b/paddle/framework/op_info.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/framework/attribute.h" + +namespace paddle { +namespace framework { +class OperatorBase; +using VariableNameMap = std::map>; + +using OpCreator = std::function; + +struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; +}; + +extern std::unordered_map& OpInfoMap(); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 8eae86e9605da74cdc37caeb9569e7500aac2a63..e03dc3a73df47af465b0824d58345861659f0797 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -19,32 +19,20 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr OpRegistry::CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, - AttributeMap attrs) { - auto it = op_info_map().find(type); - PADDLE_ENFORCE(it != op_info_map().end(), +std::unique_ptr OpRegistry::CreateOp( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs) { + auto it = OpInfoMap().find(type); + PADDLE_ENFORCE(it != OpInfoMap().end(), "Operator '%s' has not been registered.", type); it->second.checker_->Check(attrs); auto op = it->second.creator_(type, inputs, outputs, attrs); return std::unique_ptr(op); } -std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { - VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); - VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); - AttributeMap attrs; - for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); - } - - return CreateOp(op_desc.type(), inputs, outputs, attrs); -} - -OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( +static VariableNameMap ConvertOpDescVarsToVarNameMap( const google::protobuf::RepeatedPtrField& op_desc_vars) { - VarNameMap ret_val; + VariableNameMap ret_val; for (auto& var : op_desc_vars) { auto& var_names = ret_val[var.parameter()]; auto& var_names_in_proto = var.arguments(); @@ -55,6 +43,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( return ret_val; } +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { + VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); + VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); +} + std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); return std::unique_ptr(BuildGradOp(&op)); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4c2d13d639005d2d2710c19f63988333d89bce13..06530bc7d0dcf489a910a28335725deff93e8261 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" @@ -30,28 +31,16 @@ namespace paddle { namespace framework { class OpRegistry { - using VarNameMap = OperatorBase::VarNameMap; - using OpCreator = std::function; - public: - struct OpInfo { - OpCreator creator_; - std::string grad_op_type_; - OpProto* proto_; - OpAttrChecker* checker_; - }; - template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { - PADDLE_ENFORCE(op_info_map().count(op_type) == 0, + PADDLE_ENFORCE(OpInfoMap().count(op_type) == 0, "'%s' is registered more than once.", op_type); OpInfo op_info; - op_info.creator_ = [](const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, - const AttributeMap& attrs) { + op_info.creator_ = []( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) { return new OpType(type, inputs, outputs, attrs); }; op_info.grad_op_type_ = grad_op_type; @@ -70,7 +59,7 @@ class OpRegistry { op_info.proto_ = nullptr; op_info.checker_ = nullptr; } - op_info_map().insert(std::make_pair(op_type, op_info)); + OpInfoMap().insert(std::make_pair(op_type, op_info)); // register gradient op if (!grad_op_type.empty()) { RegisterOp(grad_op_type, ""); @@ -78,21 +67,13 @@ class OpRegistry { } static std::unique_ptr CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); - static VarNameMap ConvertOpDescVarsToVarNameMap( - const google::protobuf::RepeatedPtrField& op_desc_vars); - static std::unique_ptr CreateGradOp(const OperatorBase& op); - - static std::unordered_map& op_info_map() { - static std::unordered_map op_info_map_; - return op_info_map_; - } }; class Registrar { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index eadd8f3316ff1ebffb94a56b2e62d661e4e0b38f..48a7fe64ac60db89857a85cba9c069463f10aa49 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name, } OperatorBase::OperatorBase(const std::string& type, - const OperatorBase::VarNameMap& inputs, - const OperatorBase::VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { static std::atomic gUniqId(0UL); @@ -141,9 +141,9 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { } return ret_val; } - auto it = OpRegistry::op_info_map().find(type_); + auto it = OpInfoMap().find(type_); PADDLE_ENFORCE( - it != OpRegistry::op_info_map().end(), + it != OpInfoMap().end(), "Operator %s not registered, cannot figure out intermediate outputs", type_); PADDLE_ENFORCE( diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 807298088981b969622174be753ea0da72067243..83dab8631d583956a7a2b09cd037aae05355041f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include +#include "op_info.h" #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/scope.h" @@ -62,10 +63,8 @@ class ExecutionContext; */ class OperatorBase { public: - using VarNameMap = std::map>; - - OperatorBase(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs); + OperatorBase(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs); virtual ~OperatorBase() {} @@ -93,8 +92,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); - const VarNameMap& Inputs() const { return inputs_; } - const VarNameMap& Outputs() const { return outputs_; } + const VariableNameMap& Inputs() const { return inputs_; } + const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -122,11 +121,11 @@ class OperatorBase { // I (Inputs)opear // O (Outputs) // OG (Output Gradients) - VarNameMap inputs_; + VariableNameMap inputs_; // NOTE: in case of OpGrad, outputs_ contains // IG (Inputs Gradients) - VarNameMap outputs_; + VariableNameMap outputs_; AttributeMap attrs_; }; @@ -142,9 +141,11 @@ class OperatorBase { // You can also use // using PARENT_CLASS::PARENT_CLASS; // to use parent's constructor. -#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ - CLS(const std::string& type, const VarNameMap& inputs, \ - const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ +#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ + CLS(const std::string& type, \ + const ::paddle::framework::VariableNameMap& inputs, \ + const ::paddle::framework::VariableNameMap& outputs, \ + const paddle::framework::AttributeMap& attrs) \ : PARENT_CLS(type, inputs, outputs, attrs) {} class NOP : public OperatorBase { @@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - OperatorWithKernel(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const Scope& scope) const override { diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 2425b87779f6af01b0e8a91b5f574a28385f0efd..1d7efb7b9403f7c1c6bdbb27a0258f79ae032f43 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,8 +23,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: - OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) { class OperatorClone : public paddle::framework::OperatorBase { public: DEFINE_OP_CLONE_METHOD(OperatorClone); - OperatorClone(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + OperatorClone(const std::string& type, + const paddle::framework::VariableNameMap& inputs, + const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const paddle::framework::Scope& scope) const override {} diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index f0114b9e4908d65b3fddb493230777f9e500b4e1..1aec48357320e8630c5957785e26c409251d2a20 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -138,7 +138,7 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &op_info_map = OpRegistry::op_info_map(); + auto &op_info_map = OpInfoMap(); std::vector ret_values; for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { const OpProto *proto = it->second.proto_; diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index a7d710511093dfbe13a13b1222b0230bba0398bd..9bfa712d986a386c14aef7c1a16b5bc5ff11a27f 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -81,9 +81,8 @@ std::vector NetOp::OutputVars(bool has_intermediate) const { return ret_val; } -NetOp::NetOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, +NetOp::NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 885ac6eeca65998dea62c1db40b9261cceb97805..05b475d88f415fd1135a5fc6c7054f084928e8e5 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase { public: static const char kAll[]; NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {} - NetOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + + NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { this->ops_.reserve(o.ops_.size()); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 78ce0ba3c0fa4fe380e49a848c2434fe593cd00b..16bd249cb3d989c695ec9378f09d48833d70be58 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ "memories", "pre_memories", "boot_memories@grad"}; RecurrentOp::RecurrentOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); @@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { } RecurrentGradientOp::RecurrentGradientOp( - const std::string& type, const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index bcfa817de8242153b164fa091309f19a6ad8a246..1033d657a3a8f96c8b3dae8dd93d3f1f6840b59b 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm { class RecurrentOp : public framework::OperatorBase { public: - RecurrentOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); RecurrentOp(const RecurrentOp& o) : framework::OperatorBase( @@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase { class RecurrentGradientOp : public framework::OperatorBase { public: - RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + RecurrentGradientOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs); RecurrentGradientOp(const RecurrentGradientOp& o)