diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 2cdf323c53a8ba729ec74c1eacb9fa3ef272f44a..eb19defd5f13b22587062e99ae078d9a3635131a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) - -cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute) +cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index b93ab66f2f5b9cffa6d51b6e36afe552125970e4..f100c4d05489ac3bd4ceb5f11ae871985f0e5d83 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { class FcOp : public operators::NetOp { public: - FcOp(const std::string &type, const VarNameMap &inputs, - const VarNameMap &outputs, const AttributeMap &attrs) + FcOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs) : NetOp(type, inputs, outputs, attrs) { AppendOp(OpRegistry::CreateOp("mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}}, diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 0a2a41f6b62658ac8633a6e384d099f8d6641f33..b02a599a800668b22e7fe39a10fa6dc132e305bd 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,13 +20,13 @@ namespace framework { enum class OpArgType { IN, OUT }; static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, - bool is_grad, OperatorBase::VarNameMap* vars) { + bool is_grad, VariableNameMap* vars) { const auto& src_inout = src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); auto& dst_inout = *vars; - const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_; + auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); const auto& src_arg_list = - src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); + src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); for (const auto& arg : src_arg_list) { if (arg.not_in_gradient() && !is_grad) continue; const std::string src_name = arg.name(); @@ -40,26 +40,18 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, } OperatorBase* BuildGradOp(const OperatorBase* op) { - auto it = OpRegistry::op_info_map().find(op->Type()); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", op->Type()); - PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", - op->Type()); - std::string grad_op_type = it->second.grad_op_type_; - PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", - op->Type()); + auto& info = OpInfoMap::Instance().Get(op->Type()); + PADDLE_ENFORCE(info.HasGradientOp()); - OperatorBase::VarNameMap inputs; - OperatorBase::VarNameMap outputs; + VariableNameMap inputs; + VariableNameMap outputs; TransOpArg(op, OpArgType::IN, false, &inputs); // I TransOpArg(op, OpArgType::OUT, false, &inputs); // O TransOpArg(op, OpArgType::OUT, true, &inputs); // OG TransOpArg(op, OpArgType::IN, true, &outputs); // IG - it = OpRegistry::op_info_map().find(grad_op_type); - PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), - "'%s' has not been registered.", grad_op_type); - return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); + auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_); + return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); } } // namespace framework diff --git a/paddle/framework/op_info.cc b/paddle/framework/op_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..81ba29797c5f478e5d6a91236f3e8de1e6b43e49 --- /dev/null +++ b/paddle/framework/op_info.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_info.h" + +namespace paddle { +namespace framework { + +static OpInfoMap* g_op_info_map = nullptr; + +OpInfoMap& OpInfoMap::Instance() { + if (g_op_info_map == nullptr) { + g_op_info_map = new OpInfoMap(); + } + return *g_op_info_map; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h new file mode 100644 index 0000000000000000000000000000000000000000..94245c6c44aca962b0db890947a9dc5550ac0799 --- /dev/null +++ b/paddle/framework/op_info.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include +#include + +#include "paddle/framework/attribute.h" + +namespace paddle { +namespace framework { +class OperatorBase; +using VariableNameMap = std::map>; + +using OpCreator = std::function; + +struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; + + bool HasOpProtoAndChecker() const { + return proto_ != nullptr && checker_ != nullptr; + } + + const OpProto& Proto() const { + PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered"); + PADDLE_ENFORCE(proto_->IsInitialized(), + "Operator Proto must be initialized in op info"); + return *proto_; + } + + const OpAttrChecker& Checker() const { + PADDLE_ENFORCE_NOT_NULL(checker_, + "Operator Checker has not been registered"); + return *checker_; + } + + const OpCreator& Creator() const { + PADDLE_ENFORCE_NOT_NULL(creator_, + "Operator Creator has not been registered"); + return creator_; + } + + bool HasGradientOp() const { return !grad_op_type_.empty(); } +}; + +class OpInfoMap { + public: + static OpInfoMap& Instance(); + + OpInfoMap(const OpInfoMap& o) = delete; + OpInfoMap(OpInfoMap&& o) = delete; + OpInfoMap& operator=(const OpInfoMap& o) = delete; + OpInfoMap& operator=(OpInfoMap&& o) = delete; + + bool Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); + } + + void Insert(const std::string& type, const OpInfo& info) { + PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type); + map_.insert({type, info}); + } + + const OpInfo& Get(const std::string& type) const { + auto it = map_.find(type); + PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); + return it->second; + } + + template + void IterAllInfo(Callback callback) { + for (auto& it : map_) { + callback(it.first, it.second); + } + } + + private: + OpInfoMap() = default; + std::unordered_map map_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 8eae86e9605da74cdc37caeb9569e7500aac2a63..b0e85dd49f97da4a7f889fde0b5f060954947be8 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -19,32 +19,18 @@ limitations under the License. */ namespace paddle { namespace framework { -std::unique_ptr OpRegistry::CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, - AttributeMap attrs) { - auto it = op_info_map().find(type); - PADDLE_ENFORCE(it != op_info_map().end(), - "Operator '%s' has not been registered.", type); - it->second.checker_->Check(attrs); - auto op = it->second.creator_(type, inputs, outputs, attrs); +std::unique_ptr OpRegistry::CreateOp( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs) { + auto& info = OpInfoMap::Instance().Get(type); + info.Checker().Check(attrs); + auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr(op); } -std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { - VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); - VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); - AttributeMap attrs; - for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); - } - - return CreateOp(op_desc.type(), inputs, outputs, attrs); -} - -OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( +static VariableNameMap ConvertOpDescVarsToVarNameMap( const google::protobuf::RepeatedPtrField& op_desc_vars) { - VarNameMap ret_val; + VariableNameMap ret_val; for (auto& var : op_desc_vars) { auto& var_names = ret_val[var.parameter()]; auto& var_names_in_proto = var.arguments(); @@ -55,6 +41,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( return ret_val; } +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { + VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); + VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); +} + std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); return std::unique_ptr(BuildGradOp(&op)); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4c2d13d639005d2d2710c19f63988333d89bce13..2d09cde41e3f5086279f9441e0fdc52549bed5ab 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" +#include "paddle/framework/op_info.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" @@ -30,28 +31,16 @@ namespace paddle { namespace framework { class OpRegistry { - using VarNameMap = OperatorBase::VarNameMap; - using OpCreator = std::function; - public: - struct OpInfo { - OpCreator creator_; - std::string grad_op_type_; - OpProto* proto_; - OpAttrChecker* checker_; - }; - template static void RegisterOp(const std::string& op_type, const std::string& grad_op_type) { - PADDLE_ENFORCE(op_info_map().count(op_type) == 0, + PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); OpInfo op_info; - op_info.creator_ = [](const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, - const AttributeMap& attrs) { + op_info.creator_ = []( + const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) { return new OpType(type, inputs, outputs, attrs); }; op_info.grad_op_type_ = grad_op_type; @@ -70,7 +59,7 @@ class OpRegistry { op_info.proto_ = nullptr; op_info.checker_ = nullptr; } - op_info_map().insert(std::make_pair(op_type, op_info)); + OpInfoMap::Instance().Insert(op_type, op_info); // register gradient op if (!grad_op_type.empty()) { RegisterOp(grad_op_type, ""); @@ -78,21 +67,13 @@ class OpRegistry { } static std::unique_ptr CreateOp(const std::string& type, - const VarNameMap& inputs, - const VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); - static VarNameMap ConvertOpDescVarsToVarNameMap( - const google::protobuf::RepeatedPtrField& op_desc_vars); - static std::unique_ptr CreateGradOp(const OperatorBase& op); - - static std::unordered_map& op_info_map() { - static std::unordered_map op_info_map_; - return op_info_map_; - } }; class Registrar { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index eadd8f3316ff1ebffb94a56b2e62d661e4e0b38f..7abbde610f1e9c530393b9a9cabe40b826712212 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name, } OperatorBase::OperatorBase(const std::string& type, - const OperatorBase::VarNameMap& inputs, - const OperatorBase::VarNameMap& outputs, + const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { static std::atomic gUniqId(0UL); @@ -141,18 +141,10 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { } return ret_val; } - auto it = OpRegistry::op_info_map().find(type_); - PADDLE_ENFORCE( - it != OpRegistry::op_info_map().end(), - "Operator %s not registered, cannot figure out intermediate outputs", - type_); - PADDLE_ENFORCE( - it->second.proto_ != nullptr, - "Operator %s has no OpProto, cannot figure out intermediate outputs", - type_); + auto& info = OpInfoMap::Instance().Get(Type()); // get all OpProto::Var for outputs - for (auto& o : it->second.proto_->outputs()) { + for (auto& o : info.Proto().outputs()) { // ignore all intermediate output if (o.intermediate()) continue; auto out = outputs_.find(o.name()); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 807298088981b969622174be753ea0da72067243..8397570d26f06f0238e9c5afc85d721df7679257 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include +#include "op_info.h" #include "paddle/framework/attribute.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/scope.h" @@ -62,10 +63,8 @@ class ExecutionContext; */ class OperatorBase { public: - using VarNameMap = std::map>; - - OperatorBase(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs); + OperatorBase(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs); virtual ~OperatorBase() {} @@ -93,8 +92,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); - const VarNameMap& Inputs() const { return inputs_; } - const VarNameMap& Outputs() const { return outputs_; } + const VariableNameMap& Inputs() const { return inputs_; } + const VariableNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -122,30 +121,32 @@ class OperatorBase { // I (Inputs)opear // O (Outputs) // OG (Output Gradients) - VarNameMap inputs_; + VariableNameMap inputs_; // NOTE: in case of OpGrad, outputs_ contains // IG (Inputs Gradients) - VarNameMap outputs_; + VariableNameMap outputs_; AttributeMap attrs_; }; // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(CLS) \ +#define DEFINE_OP_CLONE_METHOD(cls) \ std::unique_ptr Clone() const final { \ - return std::unique_ptr(new CLS(*this)); \ + return std::unique_ptr(new cls(*this)); \ } // Macro for define a default constructor for Operator. // You can also use // using PARENT_CLASS::PARENT_CLASS; // to use parent's constructor. -#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ - CLS(const std::string& type, const VarNameMap& inputs, \ - const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ - : PARENT_CLS(type, inputs, outputs, attrs) {} +#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \ + cls(const std::string& type, \ + const ::paddle::framework::VariableNameMap& inputs, \ + const ::paddle::framework::VariableNameMap& outputs, \ + const paddle::framework::AttributeMap& attrs) \ + : parent_cls(type, inputs, outputs, attrs) {} class NOP : public OperatorBase { public: @@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - OperatorWithKernel(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const Scope& scope) const override { diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 2425b87779f6af01b0e8a91b5f574a28385f0efd..1d7efb7b9403f7c1c6bdbb27a0258f79ae032f43 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,8 +23,8 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: - OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) + OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, const AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs), x(1) {} void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, @@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) { class OperatorClone : public paddle::framework::OperatorBase { public: DEFINE_OP_CLONE_METHOD(OperatorClone); - OperatorClone(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + OperatorClone(const std::string& type, + const paddle::framework::VariableNameMap& inputs, + const paddle::framework::VariableNameMap& outputs, const paddle::framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} void InferShape(const paddle::framework::Scope& scope) const override {} diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 4539a1903eb430eb0d76a787adb32984342a468d..6c619d660091dbd3b5cf38bf3870b65544f1c9d1 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -139,19 +139,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &op_info_map = OpRegistry::op_info_map(); std::vector ret_values; - for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { - const OpProto *proto = it->second.proto_; - if (proto == nullptr) { - continue; - } - PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized"); + + OpInfoMap::Instance().IterAllInfo([&ret_values](const std::string &type, + const OpInfo &info) { + if (!info.HasOpProtoAndChecker()) return; std::string str; - PADDLE_ENFORCE(proto->SerializeToString(&str), + PADDLE_ENFORCE(info.Proto().SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); - ret_values.push_back(py::bytes(str)); - } + ret_values.emplace_back(str); + }); return ret_values; }); m.def_submodule( diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index a7d710511093dfbe13a13b1222b0230bba0398bd..9bfa712d986a386c14aef7c1a16b5bc5ff11a27f 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -81,9 +81,8 @@ std::vector NetOp::OutputVars(bool has_intermediate) const { return ret_val; } -NetOp::NetOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, +NetOp::NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : framework::OperatorBase(type, inputs, outputs, attrs) {} diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 3d3f996ef52b6c1136425ca9de0f60e7e155458f..fcd8134b2c19cae6a4d006a4cd6fe32d2d627c34 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase { public: static const char kAll[]; NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {} - NetOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + + NetOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { this->ops_.reserve(o.ops_.size()); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 78ce0ba3c0fa4fe380e49a848c2434fe593cd00b..16bd249cb3d989c695ec9378f09d48833d70be58 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ "memories", "pre_memories", "boot_memories@grad"}; RecurrentOp::RecurrentOp(const std::string& type, - const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); @@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { } RecurrentGradientOp::RecurrentGradientOp( - const std::string& type, const framework::OperatorBase::VarNameMap& inputs, - const framework::OperatorBase::VarNameMap& outputs, + const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { rnn::InitArgument(kArgName, &arg_, *this); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index bcfa817de8242153b164fa091309f19a6ad8a246..1033d657a3a8f96c8b3dae8dd93d3f1f6840b59b 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm { class RecurrentOp : public framework::OperatorBase { public: - RecurrentOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const framework::AttributeMap& attrs); + RecurrentOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs); RecurrentOp(const RecurrentOp& o) : framework::OperatorBase( @@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase { class RecurrentGradientOp : public framework::OperatorBase { public: - RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, + RecurrentGradientOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, const framework::AttributeMap& attrs); RecurrentGradientOp(const RecurrentGradientOp& o)