提交 c1219a53 编写于 作者: Y Yu Yang

Change `in_out_idxs_` to shared_ptr

* `in_out_idxs_` is shared between all operator instances of the same
  operator type.
上级 50fa7e63
...@@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization. ...@@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization.
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>; using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -212,6 +213,17 @@ class OpRegistry { ...@@ -212,6 +213,17 @@ class OpRegistry {
op_proto.IsInitialized(), op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized", "Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString()); op_type, op_proto.InitializationErrorString());
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
}
} }
static OperatorPtr CreateOp(const OpDesc& op_desc) { static OperatorPtr CreateOp(const OpDesc& op_desc) {
...@@ -220,7 +232,6 @@ class OpRegistry { ...@@ -220,7 +232,6 @@ class OpRegistry {
OperatorPtr op(creators().at(op_type)()); OperatorPtr op(creators().at(op_type)());
//! Fill op's data member. Not use constructor because it will be noising //! Fill op's data member. Not use constructor because it will be noising
//! for Op developer. //! for Op developer.
const OpProto& op_proto = protos().at(op_type);
op->type_ = op_desc.type(); op->type_ = op_desc.type();
// set op's inputs_ from desc. // set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size()); op->inputs_.reserve((size_t)op_desc.inputs_size());
...@@ -240,25 +251,31 @@ class OpRegistry { ...@@ -240,25 +251,31 @@ class OpRegistry {
//! Convert Temporary variable name to an unique variable name. //! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName(op.get()); GenerateTempVariableName(op.get());
// set argument offsets stored in op. //! set argument offsets stored in op.
CreateInOutOffsetMap(op, op_proto); {
auto var_index_it = VarIndexMaps().find(op_type);
if (var_index_it != VarIndexMaps().end()) {
op->in_out_idxs_ = var_index_it->second;
}
}
//! Other op's custom Init for a complex Op. For simple Op, the Init //! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing. //! method do nothing.
op->Init(); op->Init();
return op; return op;
} }
// Initializes op->in_out_idxs_ so that an argument's offset can be looked up
// by name. This is a thin wrapper: it forwards `proto` to the op's own
// CreateInOutOffsetMap(proto) and does nothing else.
static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
op->CreateInOutOffsetMap(proto);
}
static std::unordered_map<std::string, OpProto>& protos() { static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_; static std::unordered_map<std::string, OpProto> protos_;
return protos_; return protos_;
}; };
private: private:
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
static void GenerateTempVariableName(OperatorBase* op) { static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL); static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) { for (auto& outname : op->outputs_) {
......
...@@ -19,21 +19,10 @@ limitations under the License. */ ...@@ -19,21 +19,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Fills in_out_idxs_ with name -> position entries taken from `proto`.
// Inputs are numbered from 0 in declaration order, and outputs are numbered
// from 0 again; both kinds are stored in the same map. Must be called at most
// once per operator: the enforce below fails if the map is already populated.
void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
  PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap");

  int input_pos = 0;
  for (const auto& input : proto.inputs()) {
    in_out_idxs_[input.name()] = input_pos;
    ++input_pos;
  }

  // Output positions restart at zero; they are independent of input positions.
  int output_pos = 0;
  for (const auto& output : proto.outputs()) {
    in_out_idxs_[output.name()] = output_pos;
    ++output_pos;
  }
}
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto it = in_out_idxs_.find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name);
if (attrs_.count("input_format") == 0) { if (attrs_.count("input_format") == 0) {
return inputs_[it->second]; return inputs_[it->second];
...@@ -46,7 +35,7 @@ const std::string& OperatorBase::Input(const std::string& name) const { ...@@ -46,7 +35,7 @@ const std::string& OperatorBase::Input(const std::string& name) const {
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
auto input_format = GetAttr<std::vector<int>>("input_format"); auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_.at(name); auto offset = in_out_idxs_->at(name);
return std::vector<std::string>{ return std::vector<std::string>{
inputs_.begin() + input_format.at(offset), inputs_.begin() + input_format.at(offset),
...@@ -54,8 +43,9 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { ...@@ -54,8 +43,9 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
auto it = in_out_idxs_.find(name); auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name); PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name);
if (attrs_.count("output_format") == 0) { if (attrs_.count("output_format") == 0) {
return outputs_[it->second]; return outputs_[it->second];
...@@ -68,7 +58,7 @@ const std::string& OperatorBase::Output(const std::string& name) const { ...@@ -68,7 +58,7 @@ const std::string& OperatorBase::Output(const std::string& name) const {
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format"); auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_.at(name); auto offset = in_out_idxs_->at(name);
return std::vector<std::string>{ return std::vector<std::string>{
outputs_.begin() + output_format.at(offset), outputs_.begin() + output_format.at(offset),
......
...@@ -82,16 +82,13 @@ class OperatorBase { ...@@ -82,16 +82,13 @@ class OperatorBase {
// TODO add a vector_view to prevent memory copy. // TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const; std::vector<std::string> Outputs(const std::string& name) const;
// init in_out_idxs_ to accelerate argument's offset lookup.
void CreateInOutOffsetMap(const OpProto& proto);
public: public:
std::string type_; std::string type_;
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// store the arguments' offset described in op_desc. // store the arguments' offset described in op_desc.
std::unordered_map<std::string, int> in_out_idxs_; std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
}; };
class KernelContext { class KernelContext {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册