diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 48f77a6784b1c993a369ffe4f7544c7efe1b7de8..491ee21eec93c270cbe405e3bcbcb02b18af8fc7 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization. class OpRegistry { using OpCreator = std::function; + using VarIndexMap = std::unordered_map; public: template @@ -212,6 +213,17 @@ class OpRegistry { op_proto.IsInitialized(), "Fail to initialize %s's OpProto, because %s is not initialized", op_type, op_proto.InitializationErrorString()); + + VarIndexMaps()[op_type].reset(new VarIndexMap()); + auto& varmap = *VarIndexMaps()[op_type]; + int idx = 0; + for (auto& var : op_proto.inputs()) { + varmap[var.name()] = idx++; + } + idx = 0; + for (auto& var : op_proto.outputs()) { + varmap[var.name()] = idx++; + } } static OperatorPtr CreateOp(const OpDesc& op_desc) { @@ -220,7 +232,6 @@ class OpRegistry { OperatorPtr op(creators().at(op_type)()); //! Fill op's data member. Not use constructor because it will be noising //! for Op developer. - const OpProto& op_proto = protos().at(op_type); op->type_ = op_desc.type(); // set op's inputs_ from desc. op->inputs_.reserve((size_t)op_desc.inputs_size()); @@ -240,25 +251,31 @@ class OpRegistry { //! Convert Temporary variable name to an unique variable name. GenerateTempVariableName(op.get()); - // set argument offsets stored in op. - CreateInOutOffsetMap(op, op_proto); + //! set argument offsets stored in op. + { + auto var_index_it = VarIndexMaps().find(op_type); + if (var_index_it != VarIndexMaps().end()) { + op->in_out_idxs_ = var_index_it->second; + } + } //! Other op's custom Init for a complex Op. For simple Op, the Init //! method do nothing. op->Init(); return op; } - // init op.in_out_idxs_ to accelerate argument's offset lookup. - static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) { - op->CreateInOutOffsetMap(proto); - } - static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; }; private: + static std::unordered_map>& + VarIndexMaps() { + static std::unordered_map> maps_; + return maps_; + } + static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); for (auto& outname : op->outputs_) { diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 50cb2d936274dcc046d5641ff276aae77358d1bf..36479830535cdd49c93d965e6b68981012097b71 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -19,21 +19,10 @@ limitations under the License. */ namespace paddle { namespace framework { -void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) { - PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap"); - for (int i = 0; i < proto.inputs_size(); i++) { - const auto& name = proto.inputs()[i].name(); - in_out_idxs_[name] = i; - } - for (int i = 0; i < proto.outputs_size(); i++) { - const auto& name = proto.outputs()[i].name(); - in_out_idxs_[name] = i; - } -} - const std::string& OperatorBase::Input(const std::string& name) const { - auto it = in_out_idxs_.find(name); - PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); + auto it = in_out_idxs_->find(name); + PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", + name); if (attrs_.count("input_format") == 0) { return inputs_[it->second]; @@ -46,7 +35,7 @@ const std::string& OperatorBase::Input(const std::string& name) const { std::vector OperatorBase::Inputs(const std::string& name) const { auto input_format = GetAttr>("input_format"); - auto offset = in_out_idxs_.at(name); + auto offset = in_out_idxs_->at(name); return std::vector{ inputs_.begin() + input_format.at(offset), @@ -54,8 +43,9 @@ std::vector OperatorBase::Inputs(const std::string& name) const { } const std::string& OperatorBase::Output(const std::string& name) const { - auto it = in_out_idxs_.find(name); - PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); + auto it = in_out_idxs_->find(name); + PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", + name); if (attrs_.count("output_format") == 0) { return outputs_[it->second]; @@ -68,7 +58,7 @@ const std::string& OperatorBase::Output(const std::string& name) const { std::vector OperatorBase::Outputs(const std::string& name) const { auto output_format = GetAttr>("output_format"); - auto offset = in_out_idxs_.at(name); + auto offset = in_out_idxs_->at(name); return std::vector{ outputs_.begin() + output_format.at(offset), diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 2fe9670677c6041c0360d096b88e818676a8c929..2081b8a05c197f3fe1451f7e58d2e6f1748120a3 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -82,16 +82,13 @@ class OperatorBase { // TODO add a vector_view to prevent memory copy. std::vector Outputs(const std::string& name) const; - // init in_out_idxs_ to accelerate argument's offset lookup. - void CreateInOutOffsetMap(const OpProto& proto); - public: std::string type_; std::vector inputs_; std::vector outputs_; AttributeMap attrs_; // store the arguments' offset described in op_desc. - std::unordered_map in_out_idxs_; + std::shared_ptr> in_out_idxs_; }; class KernelContext {